[feature-wip](arrow-flight)(step3) Support authentication and user session (#24772)
This commit is contained in:
55
be/src/service/arrow_flight/auth_server_middleware.cpp
Normal file
55
be/src/service/arrow_flight/auth_server_middleware.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "service/arrow_flight/auth_server_middleware.h"
|
||||
|
||||
#include "service/arrow_flight/call_header_utils.h"
|
||||
|
||||
namespace doris {
|
||||
namespace flight {
|
||||
|
||||
void NoOpHeaderAuthServerMiddleware::SendingHeaders(
|
||||
arrow::flight::AddCallHeaders* outgoing_headers) {
|
||||
outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerDefaultToken);
|
||||
}
|
||||
|
||||
arrow::Status NoOpHeaderAuthServerMiddlewareFactory::StartCall(
|
||||
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
|
||||
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
|
||||
std::string username, password;
|
||||
ParseBasicHeader(context.incoming_headers(), username, password);
|
||||
*middleware = std::make_shared<NoOpHeaderAuthServerMiddleware>();
|
||||
return arrow::Status::OK();
|
||||
}
|
||||
|
||||
void NoOpBearerAuthServerMiddleware::SendingHeaders(
|
||||
arrow::flight::AddCallHeaders* outgoing_headers) {
|
||||
std::string bearer_token =
|
||||
FindKeyValPrefixInCallHeaders(_incoming_headers, kAuthHeader, kBearerPrefix);
|
||||
*_is_valid = (bearer_token == std::string(kBearerDefaultToken));
|
||||
}
|
||||
|
||||
arrow::Status NoOpBearerAuthServerMiddlewareFactory::StartCall(
|
||||
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
|
||||
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
|
||||
*middleware = std::make_shared<NoOpBearerAuthServerMiddleware>(context.incoming_headers(),
|
||||
&_is_valid);
|
||||
return arrow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace flight
|
||||
} // namespace doris
|
||||
88
be/src/service/arrow_flight/auth_server_middleware.h
Normal file
88
be/src/service/arrow_flight/auth_server_middleware.h
Normal file
@ -0,0 +1,88 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <arrow/status.h>
|
||||
|
||||
#include "arrow/flight/server.h"
|
||||
#include "arrow/flight/server_middleware.h"
|
||||
#include "arrow/flight/types.h"
|
||||
|
||||
namespace doris {
|
||||
namespace flight {
|
||||
|
||||
// Just return default bearer token.
|
||||
class NoOpHeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware {
|
||||
public:
|
||||
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
|
||||
|
||||
void CallCompleted(const arrow::Status& status) override {}
|
||||
|
||||
[[nodiscard]] std::string name() const override { return "NoOpHeaderAuthServerMiddleware"; }
|
||||
};
|
||||
|
||||
// Factory for base64 header authentication.
|
||||
// No actual authentication.
|
||||
class NoOpHeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
|
||||
public:
|
||||
NoOpHeaderAuthServerMiddlewareFactory() = default;
|
||||
|
||||
arrow::Status StartCall(const arrow::flight::CallInfo& info,
|
||||
const arrow::flight::ServerCallContext& context,
|
||||
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
|
||||
};
|
||||
|
||||
// A server middleware for validating incoming bearer header authentication.
|
||||
// Just compare with default bearer token.
|
||||
class NoOpBearerAuthServerMiddleware : public arrow::flight::ServerMiddleware {
|
||||
public:
|
||||
explicit NoOpBearerAuthServerMiddleware(const arrow::flight::CallHeaders& incoming_headers,
|
||||
bool* isValid)
|
||||
: _is_valid(isValid) {
|
||||
_incoming_headers = incoming_headers;
|
||||
}
|
||||
|
||||
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
|
||||
|
||||
void CallCompleted(const arrow::Status& status) override {}
|
||||
|
||||
[[nodiscard]] std::string name() const override { return "NoOpBearerAuthServerMiddleware"; }
|
||||
|
||||
private:
|
||||
arrow::flight::CallHeaders _incoming_headers;
|
||||
bool* _is_valid;
|
||||
};
|
||||
|
||||
// Factory for base64 header authentication.
|
||||
// No actual authentication.
|
||||
class NoOpBearerAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
|
||||
public:
|
||||
NoOpBearerAuthServerMiddlewareFactory() : _is_valid(false) {}
|
||||
|
||||
arrow::Status StartCall(const arrow::flight::CallInfo& info,
|
||||
const arrow::flight::ServerCallContext& context,
|
||||
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
|
||||
|
||||
[[nodiscard]] bool GetIsValid() const { return _is_valid; }
|
||||
|
||||
private:
|
||||
bool _is_valid;
|
||||
};
|
||||
|
||||
} // namespace flight
|
||||
} // namespace doris
|
||||
65
be/src/service/arrow_flight/call_header_utils.h
Normal file
65
be/src/service/arrow_flight/call_header_utils.h
Normal file
@ -0,0 +1,65 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "arrow/flight/types.h"
|
||||
#include "arrow/util/base64.h"
|
||||
|
||||
namespace doris {
|
||||
namespace flight {
|
||||
|
||||
const char kBearerDefaultToken[] = "bearertoken";
|
||||
const char kBasicPrefix[] = "Basic ";
|
||||
const char kBearerPrefix[] = "Bearer ";
|
||||
const char kAuthHeader[] = "authorization";
|
||||
|
||||
// Function to look in CallHeaders for a key that has a value starting with prefix and
|
||||
// return the rest of the value after the prefix.
|
||||
std::string FindKeyValPrefixInCallHeaders(const arrow::flight::CallHeaders& incoming_headers,
|
||||
const std::string& key, const std::string& prefix) {
|
||||
// Lambda function to compare characters without case sensitivity.
|
||||
auto char_compare = [](const char& char1, const char& char2) {
|
||||
return (::toupper(char1) == ::toupper(char2));
|
||||
};
|
||||
|
||||
auto iter = incoming_headers.find(key);
|
||||
if (iter == incoming_headers.end()) {
|
||||
return "";
|
||||
}
|
||||
const std::string val(iter->second);
|
||||
if (val.size() > prefix.length()) {
|
||||
if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) {
|
||||
return val.substr(prefix.length());
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
void ParseBasicHeader(const arrow::flight::CallHeaders& incoming_headers, std::string& username,
|
||||
std::string& password) {
|
||||
std::string encoded_credentials =
|
||||
FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
|
||||
std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
|
||||
std::getline(decoded_stream, username, ':');
|
||||
std::getline(decoded_stream, password, ':');
|
||||
}
|
||||
|
||||
} // namespace flight
|
||||
} // namespace doris
|
||||
@ -74,7 +74,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}
|
||||
FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : _impl(std::move(impl)) {}
|
||||
|
||||
arrow::Result<std::shared_ptr<FlightSqlServer>> FlightSqlServer::create() {
|
||||
std::shared_ptr<Impl> impl = std::make_shared<Impl>();
|
||||
@ -94,7 +94,7 @@ FlightSqlServer::~FlightSqlServer() {
|
||||
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> FlightSqlServer::DoGetStatement(
|
||||
const arrow::flight::ServerCallContext& context,
|
||||
const arrow::flight::sql::StatementQueryTicket& command) {
|
||||
return impl_->DoGetStatement(context, command);
|
||||
return _impl->DoGetStatement(context, command);
|
||||
}
|
||||
|
||||
Status FlightSqlServer::init(int port) {
|
||||
@ -108,6 +108,18 @@ Status FlightSqlServer::init(int port) {
|
||||
arrow::flight::Location::ForGrpcTcp(BackendOptions::get_service_bind_address(), port)
|
||||
.Value(&bind_location));
|
||||
arrow::flight::FlightServerOptions flight_options(bind_location);
|
||||
|
||||
// Not authenticated in BE flight server.
|
||||
// After the authentication between the ADBC Client and the FE flight server is completed,
|
||||
// the FE flight server will put the query id in the Ticket and send it back to the Client.
|
||||
// When the Client uses the Ticket to fetch data from the BE flight server, the BE flight
|
||||
// server will verify the query id, this step is equivalent to authentication.
|
||||
_header_middleware = std::make_shared<NoOpHeaderAuthServerMiddlewareFactory>();
|
||||
_bearer_middleware = std::make_shared<NoOpBearerAuthServerMiddlewareFactory>();
|
||||
flight_options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
|
||||
flight_options.middleware.push_back({"header-auth-server", _header_middleware});
|
||||
flight_options.middleware.push_back({"bearer-auth-server", _bearer_middleware});
|
||||
|
||||
RETURN_DORIS_STATUS_IF_ERROR(Init(flight_options));
|
||||
LOG(INFO) << "Arrow Flight Service bind to host: " << BackendOptions::get_service_bind_address()
|
||||
<< ", port: " << port;
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "arrow/result.h"
|
||||
#include "common/status.h"
|
||||
#include "service/arrow_flight/arrow_flight_batch_reader.h"
|
||||
#include "service/arrow_flight/auth_server_middleware.h"
|
||||
|
||||
namespace doris {
|
||||
namespace flight {
|
||||
@ -40,9 +41,12 @@ public:
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
std::shared_ptr<Impl> _impl;
|
||||
bool _inited = false;
|
||||
|
||||
std::shared_ptr<NoOpHeaderAuthServerMiddlewareFactory> _header_middleware;
|
||||
std::shared_ptr<NoOpBearerAuthServerMiddlewareFactory> _bearer_middleware;
|
||||
|
||||
explicit FlightSqlServer(std::shared_ptr<Impl> impl);
|
||||
};
|
||||
|
||||
|
||||
@ -390,7 +390,7 @@ public class Config extends ConfigBase {
|
||||
@ConfField(description = {"FE MySQL server 的端口号", "The port of FE MySQL server"})
|
||||
public static int query_port = 9030;
|
||||
|
||||
@ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQ server"})
|
||||
@ConfField(description = {"FE Arrow-Flight-SQL server 的端口号", "The port of FE Arrow-Flight-SQL server"})
|
||||
public static int arrow_flight_sql_port = -1;
|
||||
|
||||
@ConfField(description = {"MySQL 服务的 IO 线程数", "The number of IO threads in MySQL service"})
|
||||
@ -2211,4 +2211,15 @@ public class Config extends ConfigBase {
|
||||
"min buckets of auto bucket"
|
||||
})
|
||||
public static int autobucket_min_buckets = 1;
|
||||
|
||||
@ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰,默认值为2000",
|
||||
"The cache limit of all user tokens in Arrow Flight Server. which will be eliminated by"
|
||||
+ "LRU rules after exceeding the limit, the default value is 2000."})
|
||||
public static int arrow_flight_token_cache_size = 2000;
|
||||
|
||||
@ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位分钟,默认值为4320,即3天",
|
||||
"The alive time of the user token in Arrow Flight Server, expire after write, unit minutes,"
|
||||
+ "the default value is 4320, which is 3 days"})
|
||||
public static int arrow_flight_token_alive_time = 4320;
|
||||
|
||||
}
|
||||
|
||||
@ -759,6 +759,10 @@ under the License.
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>flight-sql</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.immutables</groupId>
|
||||
<artifactId>value</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<repositories>
|
||||
<!-- for huawei obs sdk -->
|
||||
|
||||
@ -75,6 +75,12 @@ public class ConnectContext {
|
||||
|
||||
private static final String SSL_PROTOCOL = "TLS";
|
||||
|
||||
public enum ConnectType {
|
||||
MYSQL,
|
||||
ARROW_FLIGHT
|
||||
}
|
||||
|
||||
protected volatile ConnectType connectType;
|
||||
// set this id before analyze
|
||||
protected volatile long stmtId;
|
||||
protected volatile long forwardedStmtId;
|
||||
@ -90,6 +96,8 @@ public class ConnectContext {
|
||||
protected volatile int connectionId;
|
||||
// Timestamp when the connection is make
|
||||
protected volatile long loginTime;
|
||||
// arrow flight token
|
||||
protected volatile String peerIdentity;
|
||||
// mysql net
|
||||
protected volatile MysqlChannel mysqlChannel;
|
||||
// state
|
||||
@ -268,11 +276,31 @@ public class ConnectContext {
|
||||
return notEvalNondeterministicFunction;
|
||||
}
|
||||
|
||||
public ConnectType getConnectType() {
|
||||
return connectType;
|
||||
}
|
||||
|
||||
public ConnectContext() {
|
||||
this(null);
|
||||
this((StreamConnection) null);
|
||||
}
|
||||
|
||||
public ConnectContext(String peerIdentity) {
|
||||
this.connectType = ConnectType.ARROW_FLIGHT;
|
||||
this.peerIdentity = peerIdentity;
|
||||
state = new QueryState();
|
||||
returnRows = 0;
|
||||
isKilled = false;
|
||||
sessionVariable = VariableMgr.newSessionVariable();
|
||||
mysqlChannel = new DummyMysqlChannel();
|
||||
command = MysqlCommand.COM_SLEEP;
|
||||
if (Config.use_fuzzy_session_variable) {
|
||||
sessionVariable.initFuzzyModeVariables();
|
||||
}
|
||||
setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL);
|
||||
}
|
||||
|
||||
public ConnectContext(StreamConnection connection) {
|
||||
connectType = ConnectType.MYSQL;
|
||||
state = new QueryState();
|
||||
returnRows = 0;
|
||||
serverCapability = MysqlCapability.DEFAULT_CAPABILITY;
|
||||
@ -507,6 +535,10 @@ public class ConnectContext {
|
||||
this.loginTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
public String getPeerIdentity() {
|
||||
return peerIdentity;
|
||||
}
|
||||
|
||||
public MysqlChannel getMysqlChannel() {
|
||||
return mysqlChannel;
|
||||
}
|
||||
@ -662,15 +694,28 @@ public class ConnectContext {
|
||||
this.resultSinkType = resultSinkType;
|
||||
}
|
||||
|
||||
public String getRemoteHostPortString() {
|
||||
if (connectType.equals(ConnectType.MYSQL)) {
|
||||
return getMysqlChannel().getRemoteHostPortString();
|
||||
} else if (connectType.equals(ConnectType.ARROW_FLIGHT)) {
|
||||
// TODO Get flight client IP:Port
|
||||
return peerIdentity;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// kill operation with no protect.
|
||||
public void kill(boolean killConnection) {
|
||||
LOG.warn("kill query from {}, kill connection: {}", getMysqlChannel().getRemoteHostPortString(),
|
||||
killConnection);
|
||||
LOG.warn("kill query from {}, kill connection: {}", getRemoteHostPortString(), killConnection);
|
||||
|
||||
if (killConnection) {
|
||||
isKilled = true;
|
||||
// Close channel to break connection with client
|
||||
getMysqlChannel().close();
|
||||
if (connectType.equals(ConnectType.MYSQL)) {
|
||||
// Close channel to break connection with client
|
||||
getMysqlChannel().close();
|
||||
} else if (connectType.equals(ConnectType.ARROW_FLIGHT)) {
|
||||
connectScheduler.unregisterConnection(this);
|
||||
}
|
||||
}
|
||||
// Now, cancel running query.
|
||||
cancelQuery();
|
||||
@ -695,7 +740,7 @@ public class ConnectContext {
|
||||
if (delta > sessionVariable.getWaitTimeoutS() * 1000L) {
|
||||
// Need kill this connection.
|
||||
LOG.warn("kill wait timeout connection, remote: {}, wait timeout: {}",
|
||||
getMysqlChannel().getRemoteHostPortString(), sessionVariable.getWaitTimeoutS());
|
||||
getRemoteHostPortString(), sessionVariable.getWaitTimeoutS());
|
||||
|
||||
killFlag = true;
|
||||
killConnection = true;
|
||||
@ -706,11 +751,11 @@ public class ConnectContext {
|
||||
if (executor != null && executor.isInsertStmt()) {
|
||||
timeoutTag = "insert";
|
||||
}
|
||||
//to ms
|
||||
// to ms
|
||||
long timeout = getExecTimeout() * 1000L;
|
||||
if (delta > timeout) {
|
||||
LOG.warn("kill {} timeout, remote: {}, query timeout: {}",
|
||||
timeoutTag, getMysqlChannel().getRemoteHostPortString(), timeout);
|
||||
timeoutTag, getRemoteHostPortString(), timeout);
|
||||
killFlag = true;
|
||||
}
|
||||
}
|
||||
@ -791,7 +836,7 @@ public class ConnectContext {
|
||||
}
|
||||
row.add("" + connectionId);
|
||||
row.add(ClusterNamespace.getNameFromFullName(qualifiedUser));
|
||||
row.add(getMysqlChannel().getRemoteHostPortString());
|
||||
row.add(getRemoteHostPortString());
|
||||
row.add(TimeUtils.longToTimeString(loginTime));
|
||||
row.add(defaultCatalog);
|
||||
row.add(ClusterNamespace.getNameFromFullName(currentDb));
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.ThreadPoolManager;
|
||||
import org.apache.doris.common.util.DebugUtil;
|
||||
import org.apache.doris.mysql.privilege.PrivPredicate;
|
||||
import org.apache.doris.qe.ConnectContext.ConnectType;
|
||||
import org.apache.doris.thrift.TUniqueId;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
@ -45,6 +46,7 @@ public class ConnectScheduler {
|
||||
private final AtomicInteger nextConnectionId;
|
||||
private final Map<Integer, ConnectContext> connectionMap = Maps.newConcurrentMap();
|
||||
private final Map<String, AtomicInteger> connByUser = Maps.newConcurrentMap();
|
||||
private final Map<String, Integer> flightToken2ConnectionId = Maps.newConcurrentMap();
|
||||
|
||||
// valid trace id -> query id
|
||||
private final Map<String, TUniqueId> traceId2QueryId = Maps.newConcurrentMap();
|
||||
@ -100,6 +102,9 @@ public class ConnectScheduler {
|
||||
return false;
|
||||
}
|
||||
connectionMap.put(ctx.getConnectionId(), ctx);
|
||||
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) {
|
||||
flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -111,6 +116,9 @@ public class ConnectScheduler {
|
||||
conns.decrementAndGet();
|
||||
}
|
||||
numberConnection.decrementAndGet();
|
||||
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) {
|
||||
flightToken2ConnectionId.remove(ctx.getPeerIdentity());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,6 +126,14 @@ public class ConnectScheduler {
|
||||
return connectionMap.get(connectionId);
|
||||
}
|
||||
|
||||
public ConnectContext getContext(String flightToken) {
|
||||
if (flightToken2ConnectionId.containsKey(flightToken)) {
|
||||
int connectionId = flightToken2ConnectionId.get(flightToken);
|
||||
return getContext(connectionId);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public void cancelQuery(String queryId) {
|
||||
for (ConnectContext ctx : connectionMap.values()) {
|
||||
TUniqueId qid = ctx.queryId();
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
package org.apache.doris.qe;
|
||||
|
||||
import org.apache.doris.mysql.MysqlServer;
|
||||
import org.apache.doris.service.arrowflight.FlightSqlService;
|
||||
import org.apache.doris.service.arrowflight.DorisFlightSqlService;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
@ -35,7 +35,7 @@ public class QeService {
|
||||
private MysqlServer mysqlServer;
|
||||
|
||||
private int arrowFlightSQLPort;
|
||||
private FlightSqlService flightSqlService;
|
||||
private DorisFlightSqlService dorisFlightSqlService;
|
||||
|
||||
@Deprecated
|
||||
public QeService(int port, int arrowFlightSQLPort) {
|
||||
@ -63,8 +63,8 @@ public class QeService {
|
||||
System.exit(-1);
|
||||
}
|
||||
if (arrowFlightSQLPort != -1) {
|
||||
this.flightSqlService = new FlightSqlService(arrowFlightSQLPort);
|
||||
if (!flightSqlService.start()) {
|
||||
this.dorisFlightSqlService = new DorisFlightSqlService(arrowFlightSQLPort);
|
||||
if (!dorisFlightSqlService.start()) {
|
||||
System.exit(-1);
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -22,6 +22,9 @@ package org.apache.doris.service.arrowflight;
|
||||
|
||||
import org.apache.doris.common.util.DebugUtil;
|
||||
import org.apache.doris.common.util.Util;
|
||||
import org.apache.doris.mysql.MysqlCommand;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
|
||||
|
||||
import com.google.protobuf.Any;
|
||||
import com.google.protobuf.ByteString;
|
||||
@ -67,14 +70,16 @@ import org.apache.logging.log4j.Logger;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightSqlServiceImpl.class);
|
||||
public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable {
|
||||
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class);
|
||||
private final Location location;
|
||||
private final BufferAllocator rootAllocator = new RootAllocator();
|
||||
private final SqlInfoBuilder sqlInfoBuilder;
|
||||
private final FlightSessionsManager flightSessionsManager;
|
||||
|
||||
public FlightSqlServiceImpl(final Location location) {
|
||||
public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) {
|
||||
this.location = location;
|
||||
this.flightSessionsManager = flightSessionsManager;
|
||||
sqlInfoBuilder = new SqlInfoBuilder();
|
||||
sqlInfoBuilder.withFlightSqlServerName("DorisFE")
|
||||
.withFlightSqlServerVersion("1.0")
|
||||
@ -103,9 +108,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
|
||||
@Override
|
||||
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
|
||||
final FlightDescriptor descriptor) {
|
||||
ConnectContext connectContext = null;
|
||||
try {
|
||||
connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
|
||||
// Only for ConnectContext check timeout.
|
||||
connectContext.setCommand(MysqlCommand.COM_QUERY);
|
||||
final String query = request.getQuery();
|
||||
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query);
|
||||
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query, connectContext);
|
||||
|
||||
flightStatementExecutor.executeQuery();
|
||||
|
||||
@ -123,8 +132,13 @@ public class FlightSqlServiceImpl implements FlightSqlProducer, AutoCloseable {
|
||||
if (schema == null) {
|
||||
throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException();
|
||||
}
|
||||
// TODO Set in BE callback after query end, Client client will not callback by default.
|
||||
connectContext.setCommand(MysqlCommand.COM_SLEEP);
|
||||
return new FlightInfo(schema, descriptor, endpoints, -1, -1);
|
||||
} catch (Exception e) {
|
||||
if (null != connectContext) {
|
||||
connectContext.setCommand(MysqlCommand.COM_SLEEP);
|
||||
}
|
||||
LOG.warn("get flight info statement failed, " + e.getMessage(), e);
|
||||
throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException();
|
||||
}
|
||||
@ -17,6 +17,13 @@
|
||||
|
||||
package org.apache.doris.service.arrowflight;
|
||||
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator;
|
||||
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
|
||||
import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager;
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl;
|
||||
|
||||
import org.apache.arrow.flight.FlightServer;
|
||||
import org.apache.arrow.flight.Location;
|
||||
import org.apache.arrow.memory.BufferAllocator;
|
||||
@ -29,16 +36,23 @@ import java.io.IOException;
|
||||
/**
|
||||
* flight sql protocol implementation based on nio.
|
||||
*/
|
||||
public class FlightSqlService {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightSqlService.class);
|
||||
public class DorisFlightSqlService {
|
||||
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class);
|
||||
private final FlightServer flightServer;
|
||||
private volatile boolean running;
|
||||
private final FlightTokenManager flightTokenManager;
|
||||
private final FlightSessionsManager flightSessionsManager;
|
||||
|
||||
public FlightSqlService(int port) {
|
||||
public DorisFlightSqlService(int port) {
|
||||
BufferAllocator allocator = new RootAllocator();
|
||||
Location location = Location.forGrpcInsecure("0.0.0.0", port);
|
||||
FlightSqlServiceImpl producer = new FlightSqlServiceImpl(location);
|
||||
flightServer = FlightServer.builder(allocator, location, producer).build();
|
||||
this.flightTokenManager = new FlightTokenManagerImpl(Config.arrow_flight_token_cache_size,
|
||||
Config.arrow_flight_token_alive_time);
|
||||
this.flightSessionsManager = new FlightSessionsWithTokenManager(flightTokenManager);
|
||||
|
||||
DorisFlightSqlProducer producer = new DorisFlightSqlProducer(location, flightSessionsManager);
|
||||
flightServer = FlightServer.builder(allocator, location, producer)
|
||||
.headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build();
|
||||
}
|
||||
|
||||
// start Arrow Flight SQL service, return true if success, otherwise false
|
||||
@ -21,21 +21,15 @@
|
||||
package org.apache.doris.service.arrowflight;
|
||||
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.Status;
|
||||
import org.apache.doris.common.util.DebugUtil;
|
||||
import org.apache.doris.proto.InternalService;
|
||||
import org.apache.doris.proto.Types;
|
||||
import org.apache.doris.qe.AutoCloseConnectContext;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
import org.apache.doris.qe.StmtExecutor;
|
||||
import org.apache.doris.rpc.BackendServiceProxy;
|
||||
import org.apache.doris.rpc.RpcException;
|
||||
import org.apache.doris.system.SystemInfoService;
|
||||
import org.apache.doris.thrift.TNetworkAddress;
|
||||
import org.apache.doris.thrift.TResultSinkType;
|
||||
import org.apache.doris.thrift.TStatusCode;
|
||||
import org.apache.doris.thrift.TUniqueId;
|
||||
|
||||
@ -55,8 +49,8 @@ import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
public final class FlightStatementExecutor {
|
||||
private AutoCloseConnectContext acConnectContext;
|
||||
public final class FlightStatementExecutor implements AutoCloseable {
|
||||
private ConnectContext connectContext;
|
||||
private final String query;
|
||||
private TUniqueId queryId;
|
||||
private TUniqueId finstId;
|
||||
@ -64,9 +58,10 @@ public final class FlightStatementExecutor {
|
||||
private TNetworkAddress resultInternalServiceAddr;
|
||||
private ArrayList<Expr> resultOutputExprs;
|
||||
|
||||
public FlightStatementExecutor(final String query) {
|
||||
public FlightStatementExecutor(final String query, ConnectContext connectContext) {
|
||||
this.query = query;
|
||||
acConnectContext = buildConnectContext();
|
||||
this.connectContext = connectContext;
|
||||
connectContext.setThreadLocalInfo();
|
||||
}
|
||||
|
||||
public void setQueryId(TUniqueId queryId) {
|
||||
@ -126,29 +121,14 @@ public final class FlightStatementExecutor {
|
||||
return Objects.hash(this);
|
||||
}
|
||||
|
||||
public static AutoCloseConnectContext buildConnectContext() {
|
||||
ConnectContext connectContext = new ConnectContext();
|
||||
SessionVariable sessionVariable = connectContext.getSessionVariable();
|
||||
sessionVariable.internalSession = true;
|
||||
sessionVariable.setEnablePipelineEngine(false); // TODO
|
||||
sessionVariable.setEnablePipelineXEngine(false); // TODO
|
||||
connectContext.setEnv(Env.getCurrentEnv());
|
||||
connectContext.setQualifiedUser(UserIdentity.ROOT.getQualifiedUser()); // TODO
|
||||
connectContext.setCurrentUserIdentity(UserIdentity.ROOT); // TODO
|
||||
connectContext.setStartTime();
|
||||
connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER);
|
||||
connectContext.setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL);
|
||||
return new AutoCloseConnectContext(connectContext);
|
||||
}
|
||||
|
||||
public void executeQuery() {
|
||||
try {
|
||||
UUID uuid = UUID.randomUUID();
|
||||
TUniqueId queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits());
|
||||
setQueryId(queryId);
|
||||
acConnectContext.connectContext.setQueryId(queryId);
|
||||
StmtExecutor stmtExecutor = new StmtExecutor(acConnectContext.connectContext, getQuery());
|
||||
acConnectContext.connectContext.setExecutor(stmtExecutor);
|
||||
connectContext.setQueryId(queryId);
|
||||
StmtExecutor stmtExecutor = new StmtExecutor(connectContext, getQuery());
|
||||
connectContext.setExecutor(stmtExecutor);
|
||||
stmtExecutor.executeArrowFlightQuery(this);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to coord exec", e);
|
||||
@ -221,4 +201,9 @@ public final class FlightStatementExecutor {
|
||||
DebugUtil.printId(tid), address), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
ConnectContext.remove();
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,43 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// This file is copied from
|
||||
|
||||
package org.apache.doris.service.arrowflight.auth2;
|
||||
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
|
||||
import org.immutables.value.Value;
|
||||
|
||||
/**
|
||||
* Result of Authentication.
|
||||
*/
|
||||
@Value.Immutable
|
||||
public interface FlightAuthResult {
|
||||
String getUserName();
|
||||
|
||||
UserIdentity getUserIdentity();
|
||||
|
||||
String getRemoteIp();
|
||||
|
||||
static FlightAuthResult of(String userName, UserIdentity userIdentity, String remoteIp) {
|
||||
return ImmutableFlightAuthResult.builder()
|
||||
.userName(userName)
|
||||
.userIdentity(userIdentity)
|
||||
.remoteIp(remoteIp)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,75 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.service.arrowflight.auth2;
|
||||
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.AuthenticationException;
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.arrow.flight.CallStatus;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A collection of common Flight server authentication methods.
|
||||
*/
|
||||
public final class FlightAuthUtils {
|
||||
private FlightAuthUtils() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticate against with the provided credentials.
|
||||
*
|
||||
* @param username username.
|
||||
* @param password password.
|
||||
* @param logger the slf4j logger for logging.
|
||||
* @throws org.apache.arrow.flight.FlightRuntimeException if unable to authenticate against
|
||||
* with the provided credentials.
|
||||
*/
|
||||
public static FlightAuthResult authenticateCredentials(String username, String password, String remoteIp,
|
||||
Logger logger) {
|
||||
try {
|
||||
List<UserIdentity> currentUserIdentity = Lists.newArrayList();
|
||||
|
||||
Env.getCurrentEnv().getAuth().checkPlainPassword(username, remoteIp, password, currentUserIdentity);
|
||||
Preconditions.checkState(currentUserIdentity.size() == 1);
|
||||
return FlightAuthResult.of(username, currentUserIdentity.get(0), remoteIp);
|
||||
} catch (AuthenticationException e) {
|
||||
logger.error("Unable to authenticate user {}", username, e);
|
||||
final String errMsg = "Unable to authenticate user " + username + ", exception: " + e.getMessage();
|
||||
throw CallStatus.UNAUTHENTICATED.withCause(e).withDescription(errMsg).toRuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new Bearer Token. Returns the bearer token associated with the User.
|
||||
*
|
||||
* @param flightTokenManager the TokenManager.
|
||||
* @param username the user to create a Flight server session for.
|
||||
* @param flightAuthResult the FlightAuthResult.
|
||||
* @return the token associated with the FlightTokenDetails created.
|
||||
*/
|
||||
public static String createToken(FlightTokenManager flightTokenManager, String username,
|
||||
FlightAuthResult flightAuthResult) {
|
||||
return flightTokenManager.createToken(username, flightAuthResult).getToken();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,114 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// This file is copied from
|
||||
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioBearerTokenAuthenticator.java
|
||||
// and modified by Doris
|
||||
|
||||
package org.apache.doris.service.arrowflight.auth2;
|
||||
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
|
||||
|
||||
import org.apache.arrow.flight.CallHeaders;
|
||||
import org.apache.arrow.flight.CallStatus;
|
||||
import org.apache.arrow.flight.auth2.Auth2Constants;
|
||||
import org.apache.arrow.flight.auth2.AuthUtilities;
|
||||
import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
|
||||
import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
/**
|
||||
* Doris's custom implementation of CallHeaderAuthenticator for bearer token authentication.
|
||||
* This class implements CallHeaderAuthenticator rather than BearerTokenAuthenticator. Doris
|
||||
* creates FlightTokenDetails objects when the bearer token is created and requires access to the CallHeaders
|
||||
* in getAuthResultWithBearerToken.
|
||||
*/
|
||||
|
||||
public class FlightBearerTokenAuthenticator implements CallHeaderAuthenticator {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightBearerTokenAuthenticator.class);
|
||||
|
||||
private final CallHeaderAuthenticator initialAuthenticator;
|
||||
private final FlightTokenManager flightTokenManager;
|
||||
|
||||
public FlightBearerTokenAuthenticator(FlightTokenManager flightTokenManager) {
|
||||
this.flightTokenManager = flightTokenManager;
|
||||
this.initialAuthenticator = new BasicCallHeaderAuthenticator(
|
||||
new FlightCredentialValidator(this.flightTokenManager));
|
||||
}
|
||||
|
||||
/**
|
||||
* If no bearer token is provided, the method initiates initial password and username
|
||||
* authentication. Once authenticated, client properties are retrieved from incoming CallHeaders.
|
||||
* Then it generates a token and creates a FlightTokenDetails with the retrieved client properties.
|
||||
* associated with it.
|
||||
* <p>
|
||||
* If a bearer token is provided, the method validates the provided token.
|
||||
*
|
||||
* @param incomingHeaders call headers to retrieve client properties and auth headers from.
|
||||
* @return an AuthResult with the bearer token and peer identity.
|
||||
*/
|
||||
@Override
|
||||
public AuthResult authenticate(CallHeaders incomingHeaders) {
|
||||
final String bearerToken = AuthUtilities.getValueFromAuthHeader(incomingHeaders,
|
||||
Auth2Constants.BEARER_PREFIX);
|
||||
|
||||
if (bearerToken != null) {
|
||||
return validateBearer(bearerToken);
|
||||
} else {
|
||||
final AuthResult result = initialAuthenticator.authenticate(incomingHeaders);
|
||||
return createAuthResultWithBearerToken(result.getPeerIdentity());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates provided token.
|
||||
*
|
||||
* @param token the token to validate.
|
||||
* @return an AuthResult with the bearer token and peer identity.
|
||||
*/
|
||||
AuthResult validateBearer(String token) {
|
||||
try {
|
||||
flightTokenManager.validateToken(token);
|
||||
return createAuthResultWithBearerToken(token);
|
||||
} catch (IllegalArgumentException e) {
|
||||
LOG.error("Bearer token validation failed.", e);
|
||||
throw CallStatus.UNAUTHENTICATED.toRuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Helper method to create an AuthResult.
|
||||
*
|
||||
* @param token the token to create a FlightTokenDetails for.
|
||||
* @return a new AuthResult with functionality to add given bearer token to the outgoing header.
|
||||
*/
|
||||
private AuthResult createAuthResultWithBearerToken(String token) {
|
||||
return new AuthResult() {
|
||||
@Override
|
||||
public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) {
|
||||
outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER,
|
||||
Auth2Constants.BEARER_PREFIX + token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPeerIdentity() {
|
||||
return token;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,70 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// This file is copied from
|
||||
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioCredentialValidator.java
|
||||
// and modified by Doris
|
||||
|
||||
package org.apache.doris.service.arrowflight.auth2;
|
||||
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
|
||||
|
||||
import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
|
||||
import org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
/**
|
||||
* Authentication specialized CredentialValidator implementation.
|
||||
*/
|
||||
public class FlightCredentialValidator implements BasicCallHeaderAuthenticator.CredentialValidator {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightCredentialValidator.class);
|
||||
|
||||
private final FlightTokenManager flightTokenManager;
|
||||
|
||||
public FlightCredentialValidator(FlightTokenManager flightTokenManager) {
|
||||
this.flightTokenManager = flightTokenManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticates against with the provided username and password.
|
||||
*
|
||||
* @param username username.
|
||||
* @param password user password.
|
||||
* @return AuthResult with username as the peer identity.
|
||||
*/
|
||||
@Override
|
||||
public AuthResult validate(String username, String password) {
|
||||
// TODO Add ClientAddress information while creating a Token
|
||||
String remoteIp = "0.0.0.0";
|
||||
FlightAuthResult flightAuthResult = FlightAuthUtils.authenticateCredentials(username, password, remoteIp, LOG);
|
||||
return getAuthResultWithBearerToken(flightAuthResult);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generates a bearer token, parses client properties from incoming headers, then creates a
|
||||
* FlightTokenDetails associated with the generated token and client properties.
|
||||
*
|
||||
* @param flightAuthResult the FlightAuthResult from initial authentication, with peer identity captured.
|
||||
* @return an FlightAuthResult with the bearer token and peer identity.
|
||||
*/
|
||||
AuthResult getAuthResultWithBearerToken(FlightAuthResult flightAuthResult) {
|
||||
final String username = flightAuthResult.getUserName();
|
||||
final String token = FlightAuthUtils.createToken(flightTokenManager, username, flightAuthResult);
|
||||
return () -> token;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,75 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// This file is copied from
|
||||
|
||||
package org.apache.doris.service.arrowflight.sessions;
|
||||
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.common.ErrorCode;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.service.ExecuteEnv;
|
||||
import org.apache.doris.system.SystemInfoService;
|
||||
|
||||
import org.apache.arrow.flight.CallStatus;
|
||||
|
||||
/**
|
||||
* Manages Flight User Session ConnectContext.
|
||||
*/
|
||||
public interface FlightSessionsManager {
|
||||
|
||||
/**
|
||||
* Resolves an existing ConnectContext for the given peerIdentity.
|
||||
* <p>
|
||||
*
|
||||
* @param peerIdentity identity after authorization
|
||||
* @return The ConnectContext or null if no sessionId is given.
|
||||
*/
|
||||
ConnectContext getConnectContext(String peerIdentity);
|
||||
|
||||
/**
|
||||
* Creates a ConnectContext object and store it in the local cache, assuming that peerIdentity was already
|
||||
* validated.
|
||||
*
|
||||
* @param peerIdentity identity after authorization
|
||||
*/
|
||||
ConnectContext createConnectContext(String peerIdentity);
|
||||
|
||||
public static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) {
|
||||
ConnectContext connectContext = new ConnectContext(peerIdentity);
|
||||
connectContext.setEnv(Env.getCurrentEnv());
|
||||
connectContext.setStartTime();
|
||||
connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER);
|
||||
connectContext.getSessionVariable().setEnablePipelineEngine(false); // TODO
|
||||
connectContext.getSessionVariable().setEnablePipelineXEngine(false); // TODO
|
||||
connectContext.setQualifiedUser(userIdentity.getQualifiedUser());
|
||||
connectContext.setCurrentUserIdentity(userIdentity);
|
||||
connectContext.setRemoteIP(remoteIP);
|
||||
connectContext.setUserQueryTimeout(
|
||||
connectContext.getEnv().getAuth().getQueryTimeout(connectContext.getQualifiedUser()));
|
||||
connectContext.setUserInsertTimeout(
|
||||
connectContext.getEnv().getAuth().getInsertTimeout(connectContext.getQualifiedUser()));
|
||||
|
||||
connectContext.setConnectScheduler(ExecuteEnv.getInstance().getScheduler());
|
||||
if (!ExecuteEnv.getInstance().getScheduler().registerConnection(connectContext)) {
|
||||
connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS,
|
||||
"Reach limit of connections");
|
||||
throw CallStatus.UNAUTHENTICATED.withDescription("Reach limit of connections").toRuntimeException();
|
||||
}
|
||||
return connectContext;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,68 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.service.arrowflight.sessions;
|
||||
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.service.ExecuteEnv;
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenDetails;
|
||||
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
|
||||
|
||||
import org.apache.arrow.flight.CallStatus;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
public class FlightSessionsWithTokenManager implements FlightSessionsManager {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightSessionsWithTokenManager.class);
|
||||
|
||||
private final FlightTokenManager flightTokenManager;
|
||||
|
||||
public FlightSessionsWithTokenManager(FlightTokenManager flightTokenManager) {
|
||||
this.flightTokenManager = flightTokenManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConnectContext getConnectContext(String peerIdentity) {
|
||||
ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity);
|
||||
if (null == connectContext) {
|
||||
connectContext = createConnectContext(peerIdentity);
|
||||
if (null == connectContext) {
|
||||
flightTokenManager.invalidateToken(peerIdentity);
|
||||
String err = "UserSession expire after access, need reauthorize.";
|
||||
LOG.error(err);
|
||||
throw CallStatus.UNAUTHENTICATED.withDescription(err).toRuntimeException();
|
||||
}
|
||||
return connectContext;
|
||||
}
|
||||
return connectContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConnectContext createConnectContext(String peerIdentity) {
|
||||
try {
|
||||
final FlightTokenDetails flightTokenDetails = flightTokenManager.validateToken(peerIdentity);
|
||||
if (flightTokenDetails.getCreatedSession()) {
|
||||
return null;
|
||||
}
|
||||
return FlightSessionsManager.buildConnectContext(peerIdentity, flightTokenDetails.getUserIdentity(),
|
||||
flightTokenDetails.getRemoteIp());
|
||||
} catch (IllegalArgumentException e) {
|
||||
LOG.error("Bearer token validation failed.", e);
|
||||
throw CallStatus.UNAUTHENTICATED.toRuntimeException();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,100 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.service.arrowflight.tokens;
|
||||
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* Details of a token.
|
||||
*/
|
||||
public final class FlightTokenDetails {
|
||||
|
||||
private final String token;
|
||||
private final String username;
|
||||
private final long issuedAt;
|
||||
private final long expiresAt;
|
||||
private final String remoteIp;
|
||||
private final UserIdentity userIdentity;
|
||||
private boolean createdSession = false;
|
||||
|
||||
public FlightTokenDetails() {
|
||||
this.token = "";
|
||||
this.username = "";
|
||||
this.issuedAt = 0;
|
||||
this.expiresAt = 0;
|
||||
this.remoteIp = "";
|
||||
this.userIdentity = new UserIdentity(username, remoteIp);
|
||||
}
|
||||
|
||||
public FlightTokenDetails(String token, String username, long issuedAt, long expiresAt, UserIdentity userIdentity,
|
||||
String remoteIp) {
|
||||
Preconditions.checkNotNull(token);
|
||||
Preconditions.checkNotNull(username);
|
||||
this.token = token;
|
||||
this.username = username;
|
||||
this.issuedAt = issuedAt;
|
||||
this.expiresAt = expiresAt;
|
||||
this.remoteIp = remoteIp;
|
||||
this.userIdentity = userIdentity;
|
||||
}
|
||||
|
||||
public FlightTokenDetails(String token, String username, long expiresAt) {
|
||||
Preconditions.checkNotNull(token);
|
||||
Preconditions.checkNotNull(username);
|
||||
this.token = token;
|
||||
this.username = username;
|
||||
this.expiresAt = expiresAt;
|
||||
this.issuedAt = 0;
|
||||
this.remoteIp = "";
|
||||
this.userIdentity = new UserIdentity(username, remoteIp);
|
||||
}
|
||||
|
||||
public String getToken() {
|
||||
return token;
|
||||
}
|
||||
|
||||
public String getUsername() {
|
||||
return username;
|
||||
}
|
||||
|
||||
public long getIssuedAt() {
|
||||
return issuedAt;
|
||||
}
|
||||
|
||||
public long getExpiresAt() {
|
||||
return expiresAt;
|
||||
}
|
||||
|
||||
public String getRemoteIp() {
|
||||
return remoteIp;
|
||||
}
|
||||
|
||||
public UserIdentity getUserIdentity() {
|
||||
return userIdentity;
|
||||
}
|
||||
|
||||
public void setCreatedSession(boolean createdSession) {
|
||||
this.createdSession = createdSession;
|
||||
}
|
||||
|
||||
public boolean getCreatedSession() {
|
||||
return createdSession;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,61 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManager.java
|
||||
// and modified by Doris
|
||||
|
||||
package org.apache.doris.service.arrowflight.tokens;
|
||||
|
||||
import org.apache.doris.service.arrowflight.auth2.FlightAuthResult;
|
||||
|
||||
/**
|
||||
* Token manager.
|
||||
*/
|
||||
public interface FlightTokenManager extends AutoCloseable {
|
||||
|
||||
/**
|
||||
* Generate a securely random token.
|
||||
*
|
||||
* @return a token string
|
||||
*/
|
||||
String newToken();
|
||||
|
||||
/**
|
||||
* Create a token for the session, and return details about the token.
|
||||
*
|
||||
* @param username user name
|
||||
* @param flightAuthResult auth result
|
||||
* @return token details
|
||||
*/
|
||||
FlightTokenDetails createToken(String username, FlightAuthResult flightAuthResult);
|
||||
|
||||
/**
|
||||
* Validate the token, and return details about the token.
|
||||
*
|
||||
* @param token session token
|
||||
* @return token details
|
||||
* @throws IllegalArgumentException if the token is invalid or expired
|
||||
*/
|
||||
FlightTokenDetails validateToken(String token) throws IllegalArgumentException;
|
||||
|
||||
/**
|
||||
* Invalidate the token.
|
||||
*
|
||||
* @param token session token
|
||||
*/
|
||||
void invalidateToken(String token);
|
||||
|
||||
}
|
||||
@ -0,0 +1,118 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/tokens/TokenManagerImpl.java
|
||||
// and modified by Doris
|
||||
|
||||
package org.apache.doris.service.arrowflight.tokens;
|
||||
|
||||
import org.apache.doris.service.arrowflight.auth2.FlightAuthResult;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.cache.CacheLoader;
|
||||
import com.google.common.cache.LoadingCache;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Token manager implementation.
|
||||
*/
|
||||
public class FlightTokenManagerImpl implements FlightTokenManager {
|
||||
private static final Logger LOG = LogManager.getLogger(FlightTokenManagerImpl.class);
|
||||
|
||||
private final SecureRandom generator = new SecureRandom();
|
||||
private final int cacheExpiration;
|
||||
|
||||
private LoadingCache<String, FlightTokenDetails> tokenCache;
|
||||
|
||||
public FlightTokenManagerImpl(final int cacheSize, final int cacheExpiration) {
|
||||
this.cacheExpiration = cacheExpiration;
|
||||
|
||||
this.tokenCache = CacheBuilder.newBuilder()
|
||||
.maximumSize(cacheSize)
|
||||
.expireAfterWrite(cacheExpiration, TimeUnit.MINUTES)
|
||||
.build(new CacheLoader<String, FlightTokenDetails>() {
|
||||
@Override
|
||||
public FlightTokenDetails load(String key) {
|
||||
return new FlightTokenDetails();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// From https://stackoverflow.com/questions/41107/how-to-generate-a-random-alpha-numeric-string
|
||||
// ... This works by choosing 130 bits from a cryptographically secure random bit generator, and encoding
|
||||
// them in base-32. 128 bits is considered to be cryptographically strong, but each digit in a base 32
|
||||
// number can encode 5 bits, so 128 is rounded up to the next multiple of 5 ... Why 32? Because 32 = 2^5;
|
||||
// each character will represent exactly 5 bits, and 130 bits can be evenly divided into characters.
|
||||
@Override
|
||||
public String newToken() {
|
||||
return new BigInteger(130, generator).toString(32);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlightTokenDetails createToken(final String username, final FlightAuthResult flightAuthResult) {
|
||||
final String token = newToken();
|
||||
final long now = System.currentTimeMillis();
|
||||
final long expires = now + TimeUnit.MILLISECONDS.convert(cacheExpiration, TimeUnit.MINUTES);
|
||||
|
||||
final FlightTokenDetails flightTokenDetails = new FlightTokenDetails(token, username, now, expires,
|
||||
flightAuthResult.getUserIdentity(), flightAuthResult.getRemoteIp());
|
||||
|
||||
tokenCache.put(token, flightTokenDetails);
|
||||
LOG.trace("Created flight token for user: {}", username);
|
||||
return flightTokenDetails;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlightTokenDetails validateToken(final String token) throws IllegalArgumentException {
|
||||
final FlightTokenDetails value = getTokenDetails(token);
|
||||
if (System.currentTimeMillis() >= value.getExpiresAt()) {
|
||||
tokenCache.invalidate(token); // removes from the store as well
|
||||
throw new IllegalArgumentException("token expired");
|
||||
}
|
||||
|
||||
LOG.trace("Validated flight token for user: {}", value.getUsername());
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void invalidateToken(final String token) {
|
||||
LOG.trace("Invalidate flight token, {}", token);
|
||||
tokenCache.invalidate(token); // removes from the store as well
|
||||
}
|
||||
|
||||
private FlightTokenDetails getTokenDetails(final String token) {
|
||||
Preconditions.checkNotNull(token, "invalid token");
|
||||
final FlightTokenDetails value;
|
||||
try {
|
||||
value = tokenCache.getUnchecked(token);
|
||||
} catch (CacheLoader.InvalidCacheLoadException ignored) {
|
||||
throw new IllegalArgumentException("invalid token");
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
tokenCache.invalidateAll();
|
||||
}
|
||||
}
|
||||
@ -314,6 +314,7 @@ under the License.
|
||||
<woodstox.version>6.5.1</woodstox.version>
|
||||
<kerby.version>2.0.3</kerby.version>
|
||||
<jettison.version>1.5.4</jettison.version>
|
||||
<immutables.version>2.9.3</immutables.version>
|
||||
<vesoft.client.version>3.0.0</vesoft.client.version>
|
||||
<!--todo waiting release-->
|
||||
<quartz.version>2.3.2</quartz.version>
|
||||
@ -1524,6 +1525,12 @@ under the License.
|
||||
<artifactId>arrow-jdbc</artifactId>
|
||||
<version>${arrow.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.immutables</groupId>
|
||||
<artifactId>value</artifactId>
|
||||
<version>${immutables.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
<dependencies>
|
||||
|
||||
Reference in New Issue
Block a user