[feature-wip](arrow-flight)(step3) Support authentication and user session (#24772)

This commit is contained in:
Xinyi Zou
2023-09-27 14:53:58 +08:00
committed by GitHub
parent 26818de9c8
commit 87a30dc41d
23 changed files with 1098 additions and 54 deletions

View File

@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "service/arrow_flight/auth_server_middleware.h"
#include "service/arrow_flight/call_header_utils.h"
namespace doris {
namespace flight {
void NoOpHeaderAuthServerMiddleware::SendingHeaders(
arrow::flight::AddCallHeaders* outgoing_headers) {
outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerDefaultToken);
}
arrow::Status NoOpHeaderAuthServerMiddlewareFactory::StartCall(
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
std::string username, password;
ParseBasicHeader(context.incoming_headers(), username, password);
*middleware = std::make_shared<NoOpHeaderAuthServerMiddleware>();
return arrow::Status::OK();
}
void NoOpBearerAuthServerMiddleware::SendingHeaders(
arrow::flight::AddCallHeaders* outgoing_headers) {
std::string bearer_token =
FindKeyValPrefixInCallHeaders(_incoming_headers, kAuthHeader, kBearerPrefix);
*_is_valid = (bearer_token == std::string(kBearerDefaultToken));
}
arrow::Status NoOpBearerAuthServerMiddlewareFactory::StartCall(
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
*middleware = std::make_shared<NoOpBearerAuthServerMiddleware>(context.incoming_headers(),
&_is_valid);
return arrow::Status::OK();
}
} // namespace flight
} // namespace doris

View File

@ -0,0 +1,88 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <arrow/status.h>
#include "arrow/flight/server.h"
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/types.h"
namespace doris {
namespace flight {
// Just return default bearer token.
class NoOpHeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware {
public:
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
void CallCompleted(const arrow::Status& status) override {}
[[nodiscard]] std::string name() const override { return "NoOpHeaderAuthServerMiddleware"; }
};
// Factory for base64 header authentication.
// No actual authentication.
class NoOpHeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
NoOpHeaderAuthServerMiddlewareFactory() = default;
arrow::Status StartCall(const arrow::flight::CallInfo& info,
const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
};
// A server middleware for validating incoming bearer header authentication.
// Just compare with default bearer token.
class NoOpBearerAuthServerMiddleware : public arrow::flight::ServerMiddleware {
public:
explicit NoOpBearerAuthServerMiddleware(const arrow::flight::CallHeaders& incoming_headers,
bool* isValid)
: _is_valid(isValid) {
_incoming_headers = incoming_headers;
}
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
void CallCompleted(const arrow::Status& status) override {}
[[nodiscard]] std::string name() const override { return "NoOpBearerAuthServerMiddleware"; }
private:
arrow::flight::CallHeaders _incoming_headers;
bool* _is_valid;
};
// Factory for base64 header authentication.
// No actual authentication.
class NoOpBearerAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
NoOpBearerAuthServerMiddlewareFactory() : _is_valid(false) {}
arrow::Status StartCall(const arrow::flight::CallInfo& info,
const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
[[nodiscard]] bool GetIsValid() const { return _is_valid; }
private:
bool _is_valid;
};
} // namespace flight
} // namespace doris

View File

@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <sstream>
#include "arrow/flight/types.h"
#include "arrow/util/base64.h"
namespace doris {
namespace flight {
const char kBearerDefaultToken[] = "bearertoken";
const char kBasicPrefix[] = "Basic ";
const char kBearerPrefix[] = "Bearer ";
const char kAuthHeader[] = "authorization";
// Function to look in CallHeaders for a key that has a value starting with prefix and
// return the rest of the value after the prefix.
std::string FindKeyValPrefixInCallHeaders(const arrow::flight::CallHeaders& incoming_headers,
const std::string& key, const std::string& prefix) {
// Lambda function to compare characters without case sensitivity.
auto char_compare = [](const char& char1, const char& char2) {
return (::toupper(char1) == ::toupper(char2));
};
auto iter = incoming_headers.find(key);
if (iter == incoming_headers.end()) {
return "";
}
const std::string val(iter->second);
if (val.size() > prefix.length()) {
if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) {
return val.substr(prefix.length());
}
}
return "";
}
void ParseBasicHeader(const arrow::flight::CallHeaders& incoming_headers, std::string& username,
std::string& password) {
std::string encoded_credentials =
FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
std::getline(decoded_stream, username, ':');
std::getline(decoded_stream, password, ':');
}
} // namespace flight
} // namespace doris

View File

@ -74,7 +74,7 @@ public:
}
};
FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}
FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : _impl(std::move(impl)) {}
arrow::Result<std::shared_ptr<FlightSqlServer>> FlightSqlServer::create() {
std::shared_ptr<Impl> impl = std::make_shared<Impl>();
@ -94,7 +94,7 @@ FlightSqlServer::~FlightSqlServer() {
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> FlightSqlServer::DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command) {
return impl_->DoGetStatement(context, command);
return _impl->DoGetStatement(context, command);
}
Status FlightSqlServer::init(int port) {
@ -108,6 +108,18 @@ Status FlightSqlServer::init(int port) {
arrow::flight::Location::ForGrpcTcp(BackendOptions::get_service_bind_address(), port)
.Value(&bind_location));
arrow::flight::FlightServerOptions flight_options(bind_location);
// Not authenticated in BE flight server.
// After the authentication between the ADBC Client and the FE flight server is completed,
// the FE flight server will put the query id in the Ticket and send it back to the Client.
// When the Client uses the Ticket to fetch data from the BE flight server, the BE flight
// server will verify the query id, this step is equivalent to authentication.
_header_middleware = std::make_shared<NoOpHeaderAuthServerMiddlewareFactory>();
_bearer_middleware = std::make_shared<NoOpBearerAuthServerMiddlewareFactory>();
flight_options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
flight_options.middleware.push_back({"header-auth-server", _header_middleware});
flight_options.middleware.push_back({"bearer-auth-server", _bearer_middleware});
RETURN_DORIS_STATUS_IF_ERROR(Init(flight_options));
LOG(INFO) << "Arrow Flight Service bind to host: " << BackendOptions::get_service_bind_address()
<< ", port: " << port;

View File

@ -21,6 +21,7 @@
#include "arrow/result.h"
#include "common/status.h"
#include "service/arrow_flight/arrow_flight_batch_reader.h"
#include "service/arrow_flight/auth_server_middleware.h"
namespace doris {
namespace flight {
@ -40,9 +41,12 @@ public:
private:
class Impl;
std::shared_ptr<Impl> impl_;
std::shared_ptr<Impl> _impl;
bool _inited = false;
std::shared_ptr<NoOpHeaderAuthServerMiddlewareFactory> _header_middleware;
std::shared_ptr<NoOpBearerAuthServerMiddlewareFactory> _bearer_middleware;
explicit FlightSqlServer(std::shared_ptr<Impl> impl);
};

View File

@ -390,7 +390,7 @@ public class Config extends ConfigBase {
@ConfField(description = {"FE MySQL server 的端口号", "The port of FE MySQL server"})
public static int query_port = 9030;
@ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQ server"})
@ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQL server"})
public static int arrow_flight_sql_port = -1;
@ConfField(description = {"MySQL 服务的 IO 线程数", "The number of IO threads in MySQL service"})
@ -2211,4 +2211,15 @@ public class Config extends ConfigBase {
"min buckets of auto bucket"
})
public static int autobucket_min_buckets = 1;
@ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰,默认值为2000",
"The cache limit of all user tokens in Arrow Flight Server. which will be eliminated by"
+ "LRU rules after exceeding the limit, the default value is 2000."})
public static int arrow_flight_token_cache_size = 2000;
@ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位分钟,默认值为4320,即3天",
"The alive time of the user token in Arrow Flight Server, expire after write, unit minutes,"
+ "the default value is 4320, which is 3 days"})
public static int arrow_flight_token_alive_time = 4320;
}

View File

@ -759,6 +759,10 @@ under the License.
<groupId>org.apache.arrow</groupId>
<artifactId>flight-sql</artifactId>
</dependency>
<dependency>
<groupId>org.immutables</groupId>
<artifactId>value</artifactId>
</dependency>
</dependencies>
<repositories>
<!-- for huawei obs sdk -->

View File

@ -75,6 +75,12 @@ public class ConnectContext {
private static final String SSL_PROTOCOL = "TLS";
public enum ConnectType {
MYSQL,
ARROW_FLIGHT
}
protected volatile ConnectType connectType;
// set this id before analyze
protected volatile long stmtId;
protected volatile long forwardedStmtId;
@ -90,6 +96,8 @@ public class ConnectContext {
protected volatile int connectionId;
// Timestamp when the connection is make
protected volatile long loginTime;
// arrow flight token
protected volatile String peerIdentity;
// mysql net
protected volatile MysqlChannel mysqlChannel;
// state
@ -268,11 +276,31 @@ public class ConnectContext {
return notEvalNondeterministicFunction;
}
public ConnectType getConnectType() {
return connectType;
}
public ConnectContext() {
this(null);
this((StreamConnection) null);
}
public ConnectContext(String peerIdentity) {
this.connectType = ConnectType.ARROW_FLIGHT;
this.peerIdentity = peerIdentity;
state = new QueryState();
returnRows = 0;
isKilled = false;
sessionVariable = VariableMgr.newSessionVariable();
mysqlChannel = new DummyMysqlChannel();
command = MysqlCommand.COM_SLEEP;
if (Config.use_fuzzy_session_variable) {
sessionVariable.initFuzzyModeVariables();
}
setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL);
}
public ConnectContext(StreamConnection connection) {
connectType = ConnectType.MYSQL;
state = new QueryState();
returnRows = 0;
serverCapability = MysqlCapability.DEFAULT_CAPABILITY;
@ -507,6 +535,10 @@ public class ConnectContext {
this.loginTime = System.currentTimeMillis();
}
public String getPeerIdentity() {
return peerIdentity;
}
public MysqlChannel getMysqlChannel() {
return mysqlChannel;
}
@ -662,15 +694,28 @@ public class ConnectContext {
this.resultSinkType = resultSinkType;
}
public String getRemoteHostPortString() {
if (connectType.equals(ConnectType.MYSQL)) {
return getMysqlChannel().getRemoteHostPortString();
} else if (connectType.equals(ConnectType.ARROW_FLIGHT)) {
// TODO Get flight client IP:Port
return peerIdentity;
}
return "";
}
// kill operation with no protect.
public void kill(boolean killConnection) {
LOG.warn("kill query from {}, kill connection: {}", getMysqlChannel().getRemoteHostPortString(),
killConnection);
LOG.warn("kill query from {}, kill connection: {}", getRemoteHostPortString(), killConnection);
if (killConnection) {
isKilled = true;
// Close channel to break connection with client
getMysqlChannel().close();
if (connectType.equals(ConnectType.MYSQL)) {
// Close channel to break connection with client
getMysqlChannel().close();
} else if (connectType.equals(ConnectType.ARROW_FLIGHT)) {
connectScheduler.unregisterConnection(this);
}
}
// Now, cancel running query.
cancelQuery();
@ -695,7 +740,7 @@ public class ConnectContext {
if (delta > sessionVariable.getWaitTimeoutS() * 1000L) {
// Need kill this connection.
LOG.warn("kill wait timeout connection, remote: {}, wait timeout: {}",
getMysqlChannel().getRemoteHostPortString(), sessionVariable.getWaitTimeoutS());
getRemoteHostPortString(), sessionVariable.getWaitTimeoutS());
killFlag = true;
killConnection = true;
@ -706,11 +751,11 @@ public class ConnectContext {
if (executor != null && executor.isInsertStmt()) {
timeoutTag = "insert";
}
//to ms
// to ms
long timeout = getExecTimeout() * 1000L;
if (delta > timeout) {
LOG.warn("kill {} timeout, remote: {}, query timeout: {}",
timeoutTag, getMysqlChannel().getRemoteHostPortString(), timeout);
timeoutTag, getRemoteHostPortString(), timeout);
killFlag = true;
}
}
@ -791,7 +836,7 @@ public class ConnectContext {
}
row.add("" + connectionId);
row.add(ClusterNamespace.getNameFromFullName(qualifiedUser));
row.add(getMysqlChannel().getRemoteHostPortString());
row.add(getRemoteHostPortString());
row.add(TimeUtils.longToTimeString(loginTime));
row.add(defaultCatalog);
row.add(ClusterNamespace.getNameFromFullName(currentDb));

View File

@ -21,6 +21,7 @@ import org.apache.doris.catalog.Env;
import org.apache.doris.common.ThreadPoolManager;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext.ConnectType;
import org.apache.doris.thrift.TUniqueId;
import com.google.common.collect.Lists;
@ -45,6 +46,7 @@ public class ConnectScheduler {
private final AtomicInteger nextConnectionId;
private final Map<Integer, ConnectContext> connectionMap = Maps.newConcurrentMap();
private final Map<String, AtomicInteger> connByUser = Maps.newConcurrentMap();
private final Map<String, Integer> flightToken2ConnectionId = Maps.newConcurrentMap();
// valid trace id -> query id
private final Map<String, TUniqueId> traceId2QueryId = Maps.newConcurrentMap();
@ -100,6 +102,9 @@ public class ConnectScheduler {
return false;
}
connectionMap.put(ctx.getConnectionId(), ctx);
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) {
flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId());
}
return true;
}
@ -111,6 +116,9 @@ public class ConnectScheduler {
conns.decrementAndGet();
}
numberConnection.decrementAndGet();
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) {
flightToken2ConnectionId.remove(ctx.getPeerIdentity());
}
}
}
@ -118,6 +126,14 @@ public class ConnectScheduler {
return connectionMap.get(connectionId);
}
public ConnectContext getContext(String flightToken) {
if (flightToken2ConnectionId.containsKey(flightToken)) {
int connectionId = flightToken2ConnectionId.get(flightToken);
return getContext(connectionId);
}
return null;
}
public void cancelQuery(String queryId) {
for (ConnectContext ctx : connectionMap.values()) {
TUniqueId qid = ctx.queryId();

View File

@ -18,7 +18,7 @@
package org.apache.doris.qe;
import org.apache.doris.mysql.MysqlServer;
import org.apache.doris.service.arrowflight.FlightSqlService;
import org.apache.doris.service.arrowflight.DorisFlightSqlService;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -35,7 +35,7 @@ public class QeService {
private MysqlServer mysqlServer;
private int arrowFlightSQLPort;
private FlightSqlService flightSqlService;
private DorisFlightSqlService dorisFlightSqlService;
@Deprecated
public QeService(int port, int arrowFlightSQLPort) {
@ -63,8 +63,8 @@ public class QeService {
System.exit(-1);
}
if (arrowFlightSQLPort != -1) {
this.flightSqlService = new FlightSqlService(arrowFlightSQLPort);
if (!flightSqlService.start()) {
this.dorisFlightSqlService = new DorisFlightSqlService(arrowFlightSQLPort);
if (!dorisFlightSqlService.start()) {
System.exit(-1);
}
} else {

View File

@ -22,6 +22,9 @@ package org.apache.doris.service.arrowflight;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
@ -67,14 +70,16 @@ import org.apache.logging.log4j.Logger;
import java.util.Collections;
import java.util.List;
public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
private static final Logger LOG = LogManager.getLogger(FlightSqlServiceImpl.class);
public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable {
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class);
private final Location location;
private final BufferAllocator rootAllocator = new RootAllocator();
private final SqlInfoBuilder sqlInfoBuilder;
private final FlightSessionsManager flightSessionsManager;
public FlightSqlServiceImpl(final Location location) {
public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) {
this.location = location;
this.flightSessionsManager = flightSessionsManager;
sqlInfoBuilder = new SqlInfoBuilder();
sqlInfoBuilder.withFlightSqlServerName("DorisFE")
.withFlightSqlServerVersion("1.0")
@ -103,9 +108,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
@Override
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
final FlightDescriptor descriptor) {
ConnectContext connectContext = null;
try {
connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
// Only for ConnectContext check timeout.
connectContext.setCommand(MysqlCommand.COM_QUERY);
final String query = request.getQuery();
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query);
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query, connectContext);
flightStatementExecutor.executeQuery();
@ -123,8 +132,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
if (schema == null) {
throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException();
}
// TODO Set in BE callback after query end, Client client will not callback by default.
connectContext.setCommand(MysqlCommand.COM_SLEEP);
return new FlightInfo(schema, descriptor, endpoints, -1, -1);
} catch (Exception e) {
if (null != connectContext) {
connectContext.setCommand(MysqlCommand.COM_SLEEP);
}
LOG.warn("get flight info statement failed, " + e.getMessage(), e);
throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException();
}

View File

@ -17,6 +17,13 @@
package org.apache.doris.service.arrowflight;
import org.apache.doris.common.Config;
import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
@ -29,16 +36,23 @@ import java.io.IOException;
/**
* flight sql protocol implementation based on nio.
*/
public class FlightSqlService {
private static final Logger LOG = LogManager.getLogger(FlightSqlService.class);
public class DorisFlightSqlService {
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class);
private final FlightServer flightServer;
private volatile boolean running;
private final FlightTokenManager flightTokenManager;
private final FlightSessionsManager flightSessionsManager;
public FlightSqlService(int port) {
public DorisFlightSqlService(int port) {
BufferAllocator allocator = new RootAllocator();
Location location = Location.forGrpcInsecure("0.0.0.0", port);
FlightSqlServiceImpl producer = new FlightSqlServiceImpl(location);
flightServer = FlightServer.builder(allocator, location, producer).build();
this.flightTokenManager = new FlightTokenManagerImpl(Config.arrow_flight_token_cache_size,
Config.arrow_flight_token_alive_time);
this.flightSessionsManager = new FlightSessionsWithTokenManager(flightTokenManager);
DorisFlightSqlProducer producer = new DorisFlightSqlProducer(location, flightSessionsManager);
flightServer = FlightServer.builder(allocator, location, producer)
.headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build();
}
// start Arrow Flight SQL service, return true if success, otherwise false

View File

@ -21,21 +21,15 @@
package org.apache.doris.service.arrowflight;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Status;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.Types;
import org.apache.doris.qe.AutoCloseConnectContext;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.rpc.RpcException;
import org.apache.doris.system.SystemInfoService;
import org.apache.doris.thrift.TNetworkAddress;
import org.apache.doris.thrift.TResultSinkType;
import org.apache.doris.thrift.TStatusCode;
import org.apache.doris.thrift.TUniqueId;
@ -55,8 +49,8 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public final class FlightStatementExecutor {
private AutoCloseConnectContext acConnectContext;
public final class FlightStatementExecutor implements AutoCloseable {
private ConnectContext connectContext;
private final String query;
private TUniqueId queryId;
private TUniqueId finstId;
@ -64,9 +58,10 @@ public final class FlightStatementExecutor {
private TNetworkAddress resultInternalServiceAddr;
private ArrayList<Expr> resultOutputExprs;
public FlightStatementExecutor(final String query) {
public FlightStatementExecutor(final String query, ConnectContext connectContext) {
this.query = query;
acConnectContext = buildConnectContext();
this.connectContext = connectContext;
connectContext.setThreadLocalInfo();
}
public void setQueryId(TUniqueId queryId) {
@ -126,29 +121,14 @@ public final class FlightStatementExecutor {
return Objects.hash(this);
}
public static AutoCloseConnectContext buildConnectContext() {
ConnectContext connectContext = new ConnectContext();
SessionVariable sessionVariable = connectContext.getSessionVariable();
sessionVariable.internalSession = true;
sessionVariable.setEnablePipelineEngine(false); // TODO
sessionVariable.setEnablePipelineXEngine(false); // TODO
connectContext.setEnv(Env.getCurrentEnv());
connectContext.setQualifiedUser(UserIdentity.ROOT.getQualifiedUser()); // TODO
connectContext.setCurrentUserIdentity(UserIdentity.ROOT); // TODO
connectContext.setStartTime();
connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER);
connectContext.setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL);
return new AutoCloseConnectContext(connectContext);
}
public void executeQuery() {
try {
UUID uuid = UUID.randomUUID();
TUniqueId queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits());
setQueryId(queryId);
acConnectContext.connectContext.setQueryId(queryId);
StmtExecutor stmtExecutor = new StmtExecutor(acConnectContext.connectContext, getQuery());
acConnectContext.connectContext.setExecutor(stmtExecutor);
connectContext.setQueryId(queryId);
StmtExecutor stmtExecutor = new StmtExecutor(connectContext, getQuery());
connectContext.setExecutor(stmtExecutor);
stmtExecutor.executeArrowFlightQuery(this);
} catch (Exception e) {
throw new RuntimeException("Failed to coord exec", e);
@ -221,4 +201,9 @@ public final class FlightStatementExecutor {
DebugUtil.printId(tid), address), e);
}
}
@Override
public void close() throws Exception {
ConnectContext.remove();
}
}

View File

@ -0,0 +1,43 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
package org.apache.doris.service.arrowflight.auth2;
import org.apache.doris.analysis.UserIdentity;
import org.immutables.value.Value;
/**
* Result of Authentication.
*/
@Value.Immutable
public interface FlightAuthResult {
String getUserName();
UserIdentity getUserIdentity();
String getRemoteIp();
static FlightAuthResult of(String userName, UserIdentity userIdentity, String remoteIp) {
return ImmutableFlightAuthResult.builder()
.userName(userName)
.userIdentity(userIdentity)
.remoteIp(remoteIp)
.build();
}
}

View File

@ -0,0 +1,75 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.service.arrowflight.auth2;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AuthenticationException;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.arrow.flight.CallStatus;
import org.apache.logging.log4j.Logger;
import java.util.List;
/**
* A collection of common Flight server authentication methods.
*/
public final class FlightAuthUtils {
private FlightAuthUtils() {
}
/**
* Authenticate against with the provided credentials.
*
* @param username username.
* @param password password.
* @param logger the slf4j logger for logging.
* @throws org.apache.arrow.flight.FlightRuntimeException if unable to authenticate against
* with the provided credentials.
*/
public static FlightAuthResult authenticateCredentials(String username, String password, String remoteIp,
Logger logger) {
try {
List<UserIdentity> currentUserIdentity = Lists.newArrayList();
Env.getCurrentEnv().getAuth().checkPlainPassword(username, remoteIp, password, currentUserIdentity);
Preconditions.checkState(currentUserIdentity.size() == 1);
return FlightAuthResult.of(username, currentUserIdentity.get(0), remoteIp);
} catch (AuthenticationException e) {
logger.error("Unable to authenticate user {}", username, e);
final String errMsg = "Unable to authenticate user " + username + ", exception: " + e.getMessage();
throw CallStatus.UNAUTHENTICATED.withCause(e).withDescription(errMsg).toRuntimeException();
}
}
/**
* Creates a new Bearer Token. Returns the bearer token associated with the User.
*
* @param flightTokenManager the TokenManager.
* @param username the user to create a Flight server session for.
* @param flightAuthResult the FlightAuthResult.
* @return the token associated with the FlightTokenDetails created.
*/
public static String createToken(FlightTokenManager flightTokenManager, String username,
FlightAuthResult flightAuthResult) {
return flightTokenManager.createToken(username, flightAuthResult).getToken();
}
}

View File

@ -0,0 +1,114 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioBearerTokenAuthenticator.java
// and modified by Doris
package org.apache.doris.service.arrowflight.auth2;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.auth2.Auth2Constants;
import org.apache.arrow.flight.auth2.AuthUtilities;
import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
* Doris's custom implementation of CallHeaderAuthenticator for bearer token authentication.
* This class implements CallHeaderAuthenticator rather than BearerTokenAuthenticator. Doris
* creates FlightTokenDetails objects when the bearer token is created and requires access to the CallHeaders
* in getAuthResultWithBearerToken.
*/
public class FlightBearerTokenAuthenticator implements CallHeaderAuthenticator {
private static final Logger LOG = LogManager.getLogger(FlightBearerTokenAuthenticator.class);
private final CallHeaderAuthenticator initialAuthenticator;
private final FlightTokenManager flightTokenManager;
public FlightBearerTokenAuthenticator(FlightTokenManager flightTokenManager) {
this.flightTokenManager = flightTokenManager;
this.initialAuthenticator = new BasicCallHeaderAuthenticator(
new FlightCredentialValidator(this.flightTokenManager));
}
/**
* If no bearer token is provided, the method initiates initial password and username
* authentication. Once authenticated, client properties are retrieved from incoming CallHeaders.
* Then it generates a token and creates a FlightTokenDetails with the retrieved client properties.
* associated with it.
* <p>
* If a bearer token is provided, the method validates the provided token.
*
* @param incomingHeaders call headers to retrieve client properties and auth headers from.
* @return an AuthResult with the bearer token and peer identity.
*/
@Override
public AuthResult authenticate(CallHeaders incomingHeaders) {
final String bearerToken = AuthUtilities.getValueFromAuthHeader(incomingHeaders,
Auth2Constants.BEARER_PREFIX);
if (bearerToken != null) {
return validateBearer(bearerToken);
} else {
final AuthResult result = initialAuthenticator.authenticate(incomingHeaders);
return createAuthResultWithBearerToken(result.getPeerIdentity());
}
}
/**
* Validates provided token.
*
* @param token the token to validate.
* @return an AuthResult with the bearer token and peer identity.
*/
AuthResult validateBearer(String token) {
try {
flightTokenManager.validateToken(token);
return createAuthResultWithBearerToken(token);
} catch (IllegalArgumentException e) {
LOG.error("Bearer token validation failed.", e);
throw CallStatus.UNAUTHENTICATED.toRuntimeException();
}
}
/**
* Helper method to create an AuthResult.
*
* @param token the token to create a FlightTokenDetails for.
* @return a new AuthResult with functionality to add given bearer token to the outgoing header.
*/
private AuthResult createAuthResultWithBearerToken(String token) {
return new AuthResult() {
@Override
public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) {
outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER,
Auth2Constants.BEARER_PREFIX + token);
}
@Override
public String getPeerIdentity() {
return token;
}
};
}
}

View File

@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioCredentialValidator.java
// and modified by Doris
package org.apache.doris.service.arrowflight.auth2;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
import org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
* Authentication specialized CredentialValidator implementation.
*/
public class FlightCredentialValidator implements BasicCallHeaderAuthenticator.CredentialValidator {
private static final Logger LOG = LogManager.getLogger(FlightCredentialValidator.class);
private final FlightTokenManager flightTokenManager;
public FlightCredentialValidator(FlightTokenManager flightTokenManager) {
this.flightTokenManager = flightTokenManager;
}
/**
* Authenticates against with the provided username and password.
*
* @param username username.
* @param password user password.
* @return AuthResult with username as the peer identity.
*/
@Override
public AuthResult validate(String username, String password) {
// TODO Add ClientAddress information while creating a Token
String remoteIp = "0.0.0.0";
FlightAuthResult flightAuthResult = FlightAuthUtils.authenticateCredentials(username, password, remoteIp, LOG);
return getAuthResultWithBearerToken(flightAuthResult);
}
/**
* Generates a bearer token, parses client properties from incoming headers, then creates a
* FlightTokenDetails associated with the generated token and client properties.
*
* @param flightAuthResult the FlightAuthResult from initial authentication, with peer identity captured.
* @return an FlightAuthResult with the bearer token and peer identity.
*/
AuthResult getAuthResultWithBearerToken(FlightAuthResult flightAuthResult) {
final String username = flightAuthResult.getUserName();
final String token = FlightAuthUtils.createToken(flightTokenManager, username, flightAuthResult);
return () -> token;
}
}

View File

@ -0,0 +1,75 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
package org.apache.doris.service.arrowflight.sessions;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.service.ExecuteEnv;
import org.apache.doris.system.SystemInfoService;
import org.apache.arrow.flight.CallStatus;
/**
* Manages Flight User Session ConnectContext.
*/
public interface FlightSessionsManager {
/**
* Resolves an existing ConnectContext for the given peerIdentity.
* <p>
*
* @param peerIdentity identity after authorization
* @return The ConnectContext or null if no sessionId is given.
*/
ConnectContext getConnectContext(String peerIdentity);
/**
* Creates a ConnectContext object and store it in the local cache, assuming that peerIdentity was already
* validated.
*
* @param peerIdentity identity after authorization
*/
ConnectContext createConnectContext(String peerIdentity);
public static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) {
ConnectContext connectContext = new ConnectContext(peerIdentity);
connectContext.setEnv(Env.getCurrentEnv());
connectContext.setStartTime();
connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER);
connectContext.getSessionVariable().setEnablePipelineEngine(false); // TODO
connectContext.getSessionVariable().setEnablePipelineXEngine(false); // TODO
connectContext.setQualifiedUser(userIdentity.getQualifiedUser());
connectContext.setCurrentUserIdentity(userIdentity);
connectContext.setRemoteIP(remoteIP);
connectContext.setUserQueryTimeout(
connectContext.getEnv().getAuth().getQueryTimeout(connectContext.getQualifiedUser()));
connectContext.setUserInsertTimeout(
connectContext.getEnv().getAuth().getInsertTimeout(connectContext.getQualifiedUser()));
connectContext.setConnectScheduler(ExecuteEnv.getInstance().getScheduler());
if (!ExecuteEnv.getInstance().getScheduler().registerConnection(connectContext)) {
connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS,
"Reach limit of connections");
throw CallStatus.UNAUTHENTICATED.withDescription("Reach limit of connections").toRuntimeException();
}
return connectContext;
}
}

View File

@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.service.arrowflight.sessions;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.service.ExecuteEnv;
import org.apache.doris.service.arrowflight.tokens.FlightTokenDetails;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.arrow.flight.CallStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
public class FlightSessionsWithTokenManager implements FlightSessionsManager {
private static final Logger LOG = LogManager.getLogger(FlightSessionsWithTokenManager.class);
private final FlightTokenManager flightTokenManager;
public FlightSessionsWithTokenManager(FlightTokenManager flightTokenManager) {
this.flightTokenManager = flightTokenManager;
}
@Override
public ConnectContext getConnectContext(String peerIdentity) {
ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity);
if (null == connectContext) {
connectContext = createConnectContext(peerIdentity);
if (null == connectContext) {
flightTokenManager.invalidateToken(peerIdentity);
String err = "UserSession expire after access, need reauthorize.";
LOG.error(err);
throw CallStatus.UNAUTHENTICATED.withDescription(err).toRuntimeException();
}
return connectContext;
}
return connectContext;
}
@Override
public ConnectContext createConnectContext(String peerIdentity) {
try {
final FlightTokenDetails flightTokenDetails = flightTokenManager.validateToken(peerIdentity);
if (flightTokenDetails.getCreatedSession()) {
return null;
}
return FlightSessionsManager.buildConnectContext(peerIdentity, flightTokenDetails.getUserIdentity(),
flightTokenDetails.getRemoteIp());
} catch (IllegalArgumentException e) {
LOG.error("Bearer token validation failed.", e);
throw CallStatus.UNAUTHENTICATED.toRuntimeException();
}
}
}

View File

@ -0,0 +1,100 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.service.arrowflight.tokens;
import org.apache.doris.analysis.UserIdentity;
import com.google.common.base.Preconditions;
/**
* Details of a token.
*/
public final class FlightTokenDetails {
private final String token;
private final String username;
private final long issuedAt;
private final long expiresAt;
private final String remoteIp;
private final UserIdentity userIdentity;
private boolean createdSession = false;
public FlightTokenDetails() {
this.token = "";
this.username = "";
this.issuedAt = 0;
this.expiresAt = 0;
this.remoteIp = "";
this.userIdentity = new UserIdentity(username, remoteIp);
}
public FlightTokenDetails(String token, String username, long issuedAt, long expiresAt, UserIdentity userIdentity,
String remoteIp) {
Preconditions.checkNotNull(token);
Preconditions.checkNotNull(username);
this.token = token;
this.username = username;
this.issuedAt = issuedAt;
this.expiresAt = expiresAt;
this.remoteIp = remoteIp;
this.userIdentity = userIdentity;
}
public FlightTokenDetails(String token, String username, long expiresAt) {
Preconditions.checkNotNull(token);
Preconditions.checkNotNull(username);
this.token = token;
this.username = username;
this.expiresAt = expiresAt;
this.issuedAt = 0;
this.remoteIp = "";
this.userIdentity = new UserIdentity(username, remoteIp);
}
public String getToken() {
return token;
}
public String getUsername() {
return username;
}
public long getIssuedAt() {
return issuedAt;
}
public long getExpiresAt() {
return expiresAt;
}
public String getRemoteIp() {
return remoteIp;
}
public UserIdentity getUserIdentity() {
return userIdentity;
}
public void setCreatedSession(boolean createdSession) {
this.createdSession = createdSession;
}
public boolean getCreatedSession() {
return createdSession;
}
}

View File

@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManager.java
// and modified by Doris
package org.apache.doris.service.arrowflight.tokens;
import org.apache.doris.service.arrowflight.auth2.FlightAuthResult;
/**
* Token manager.
*/
public interface FlightTokenManager extends AutoCloseable {
/**
* Generate a securely random token.
*
* @return a token string
*/
String newToken();
/**
* Create a token for the session, and return details about the token.
*
* @param username user name
* @param flightAuthResult auth result
* @return token details
*/
FlightTokenDetails createToken(String username, FlightAuthResult flightAuthResult);
/**
* Validate the token, and return details about the token.
*
* @param token session token
* @return token details
* @throws IllegalArgumentException if the token is invalid or expired
*/
FlightTokenDetails validateToken(String token) throws IllegalArgumentException;
/**
* Invalidate the token.
*
* @param token session token
*/
void invalidateToken(String token);
}

View File

@ -0,0 +1,118 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManagerImpl.java
// and modified by Doris
package org.apache.doris.service.arrowflight.tokens;
import org.apache.doris.service.arrowflight.auth2.FlightAuthResult;
import com.google.common.base.Preconditions;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.concurrent.TimeUnit;
/**
* Token manager implementation.
*/
public class FlightTokenManagerImpl implements FlightTokenManager {
private static final Logger LOG = LogManager.getLogger(FlightTokenManagerImpl.class);
private final SecureRandom generator = new SecureRandom();
private final int cacheExpiration;
private LoadingCache<String, FlightTokenDetails> tokenCache;
public FlightTokenManagerImpl(final int cacheSize, final int cacheExpiration) {
this.cacheExpiration = cacheExpiration;
this.tokenCache = CacheBuilder.newBuilder()
.maximumSize(cacheSize)
.expireAfterWrite(cacheExpiration, TimeUnit.MINUTES)
.build(new CacheLoader<String, FlightTokenDetails>() {
@Override
public FlightTokenDetails load(String key) {
return new FlightTokenDetails();
}
});
}
// From https://stackoverflow.com/questions/41107/how-to-generate-a-random-alpha-numeric-string
// ... This works by choosing 130 bits from a cryptographically secure random bit generator, and encoding
// them in base-32. 128 bits is considered to be cryptographically strong, but each digit in a base 32
// number can encode 5 bits, so 128 is rounded up to the next multiple of 5 ... Why 32? Because 32 = 2^5;
// each character will represent exactly 5 bits, and 130 bits can be evenly divided into characters.
@Override
public String newToken() {
return new BigInteger(130, generator).toString(32);
}
@Override
public FlightTokenDetails createToken(final String username, final FlightAuthResult flightAuthResult) {
final String token = newToken();
final long now = System.currentTimeMillis();
final long expires = now + TimeUnit.MILLISECONDS.convert(cacheExpiration, TimeUnit.MINUTES);
final FlightTokenDetails flightTokenDetails = new FlightTokenDetails(token, username, now, expires,
flightAuthResult.getUserIdentity(), flightAuthResult.getRemoteIp());
tokenCache.put(token, flightTokenDetails);
LOG.trace("Created flight token for user: {}", username);
return flightTokenDetails;
}
@Override
public FlightTokenDetails validateToken(final String token) throws IllegalArgumentException {
final FlightTokenDetails value = getTokenDetails(token);
if (System.currentTimeMillis() >= value.getExpiresAt()) {
tokenCache.invalidate(token); // removes from the store as well
throw new IllegalArgumentException("token expired");
}
LOG.trace("Validated flight token for user: {}", value.getUsername());
return value;
}
@Override
public void invalidateToken(final String token) {
LOG.trace("Invalidate flight token, {}", token);
tokenCache.invalidate(token); // removes from the store as well
}
private FlightTokenDetails getTokenDetails(final String token) {
Preconditions.checkNotNull(token, "invalid token");
final FlightTokenDetails value;
try {
value = tokenCache.getUnchecked(token);
} catch (CacheLoader.InvalidCacheLoadException ignored) {
throw new IllegalArgumentException("invalid token");
}
return value;
}
@Override
public void close() throws Exception {
tokenCache.invalidateAll();
}
}

View File

@ -314,6 +314,7 @@ under the License.
<woodstox.version>6.5.1</woodstox.version>
<kerby.version>2.0.3</kerby.version>
<jettison.version>1.5.4</jettison.version>
<immutables.version>2.9.3</immutables.version>
<vesoft.client.version>3.0.0</vesoft.client.version>
<!--todo waiting release-->
<quartz.version>2.3.2</quartz.version>
@ -1524,6 +1525,12 @@ under the License.
<artifactId>arrow-jdbc</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.immutables</groupId>
<artifactId>value</artifactId>
<version>${immutables.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>