branch-2.1: [fix](arrow-flight-sql) Separate arrow-flight-sql connection and mysql connection (#51110)

This commit is contained in:
Xinyi Zou
2025-05-22 20:39:48 +08:00
committed by GitHub
parent df464f84b1
commit fc9beac468
15 changed files with 421 additions and 172 deletions

View File

@ -2569,21 +2569,23 @@ public class Config extends ConfigBase {
})
public static int autobucket_max_buckets = 128;
@ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰,默认值为512, "
+ "并强制限制小于 qe_max_connection/2, 避免`Reach limit of connections`, "
+ "因为arrow flight sql是无状态的协议,连接通常不会主动断开,"
+ "bearer token 从 cache 淘汰的同时会 unregister Connection.",
"The cache limit of all user tokens in Arrow Flight Server. which will be eliminated by"
+ "LRU rules after exceeding the limit, the default value is 512, the mandatory limit is "
+ "less than qe_max_connection/2 to avoid `Reach limit of connections`, "
+ "because arrow flight sql is a stateless protocol, the connection is usually not actively "
+ "disconnected, bearer token is evict from the cache will unregister ConnectContext."})
public static int arrow_flight_token_cache_size = 512;
@ConfField(description = {"单个 FE 的 Arrow Flight Server 的最大连接数。",
"Maximal number of connections of Arrow Flight Server per FE."})
public static int arrow_flight_max_connections = 4096;
@ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位分钟,默认值为4320,即3天",
"The alive time of the user token in Arrow Flight Server, expire after write, unit minutes,"
+ "the default value is 4320, which is 3 days"})
public static int arrow_flight_token_alive_time = 4320;
@ConfField(description = {"(已弃用,被 arrow_flight_max_connection 替代) Arrow Flight Server中所有用户token的缓存上限,"
+ "超过后LRU淘汰, arrow flight sql是无状态的协议,连接通常不会主动断开,"
+ "bearer token 从 cache 淘汰的同时会 unregister Connection.",
"(Deprecated, replaced by arrow_flight_max_connection) The cache limit of all user tokens in "
+ "Arrow Flight Server. which will be eliminated by LRU rules after exceeding the limit, "
+ "arrow flight sql is a stateless protocol, the connection is usually not actively disconnected, "
+ "bearer token is evict from the cache will unregister ConnectContext."})
public static int arrow_flight_token_cache_size = 4096;
@ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位秒,默认值为86400,即1天",
"The alive time of the user token in Arrow Flight Server, expire after write, unit second,"
+ "the default value is 86400, which is 1 days"})
public static int arrow_flight_token_alive_time_second = 86400;
@ConfField(mutable = true, description = {
"Doris 为了兼用 mysql 周边工具生态,会内置一个名为 mysql 的数据库,如果该数据库与用户自建数据库冲突,"

View File

@ -89,16 +89,17 @@ public class AcceptListener implements ChannelListener<AcceptingChannel<StreamCo
if (!MysqlProto.negotiate(context)) {
throw new AfterConnectedException("mysql negotiate failed");
}
int res = connectScheduler.registerConnection(context);
int res = connectScheduler.getConnectPoolMgr().registerConnection(context);
if (res == -1) {
MysqlProto.sendResponsePacket(context);
connection.setCloseListener(
streamConnection -> connectScheduler.unregisterConnection(context));
streamConnection -> connectScheduler.getConnectPoolMgr()
.unregisterConnection(context));
} else {
long userConnLimit = context.getEnv().getAuth().getMaxConn(context.getQualifiedUser());
String errMsg = String.format(
"Reach limit of connections. Total: %d, User: %d, Current: %d",
connectScheduler.getMaxConnections(), userConnLimit, res);
connectScheduler.getConnectPoolMgr().getMaxConnections(), userConnLimit, res);
context.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS, errMsg);
MysqlProto.sendResponsePacket(context);
throw new AfterConnectedException(errMsg);

View File

@ -867,7 +867,7 @@ public class ConnectContext {
}
this.queryId = queryId;
if (connectScheduler != null && !Strings.isNullOrEmpty(traceId)) {
connectScheduler.putTraceId2QueryId(traceId, queryId);
connectScheduler.getConnectPoolMgr().putTraceId2QueryId(traceId, queryId);
}
}

View File

@ -0,0 +1,167 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.qe;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext.ThreadInfo;
import org.apache.doris.thrift.TUniqueId;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
public class ConnectPoolMgr {
private static final Logger LOG = LogManager.getLogger(ConnectPoolMgr.class);
protected final int maxConnections;
protected final AtomicInteger numberConnection;
protected final Map<Integer, ConnectContext> connectionMap = Maps.newConcurrentMap();
protected final Map<String, AtomicInteger> connByUser = Maps.newConcurrentMap();
// valid trace id -> query id
protected final Map<String, TUniqueId> traceId2QueryId = Maps.newConcurrentMap();
public ConnectPoolMgr(int maxConnections) {
this.maxConnections = maxConnections;
numberConnection = new AtomicInteger(0);
}
public void timeoutChecker(long now) {
for (ConnectContext connectContext : connectionMap.values()) {
connectContext.checkTimeout(now);
}
}
// Register one connection with its connection id.
// Return -1 means register OK
// Return >=0 means register failed, and return value is current connection num.
public int registerConnection(ConnectContext ctx) {
if (numberConnection.incrementAndGet() > maxConnections) {
numberConnection.decrementAndGet();
return numberConnection.get();
}
// Check user
connByUser.putIfAbsent(ctx.getQualifiedUser(), new AtomicInteger(0));
AtomicInteger conns = connByUser.get(ctx.getQualifiedUser());
if (conns.incrementAndGet() > ctx.getEnv().getAuth().getMaxConn(ctx.getQualifiedUser())) {
conns.decrementAndGet();
numberConnection.decrementAndGet();
return numberConnection.get();
}
connectionMap.put(ctx.getConnectionId(), ctx);
return -1;
}
public void unregisterConnection(ConnectContext ctx) {
ctx.closeTxn();
if (connectionMap.remove(ctx.getConnectionId()) != null) {
AtomicInteger conns = connByUser.get(ctx.getQualifiedUser());
if (conns != null) {
conns.decrementAndGet();
}
numberConnection.decrementAndGet();
}
}
public ConnectContext getContext(int connectionId) {
return connectionMap.get(connectionId);
}
public ConnectContext getContextWithQueryId(String queryId) {
for (ConnectContext context : connectionMap.values()) {
if (queryId.equals(DebugUtil.printId(context.queryId))) {
return context;
}
}
return null;
}
public boolean cancelQuery(String queryId, String cancelReason) {
for (ConnectContext ctx : connectionMap.values()) {
TUniqueId qid = ctx.queryId();
if (qid != null && DebugUtil.printId(qid).equals(queryId)) {
ctx.cancelQuery(cancelReason);
return true;
}
}
return false;
}
public int getConnectionNum() {
return numberConnection.get();
}
public List<ThreadInfo> listConnection(String user, boolean isFull) {
List<ConnectContext.ThreadInfo> infos = Lists.newArrayList();
for (ConnectContext ctx : connectionMap.values()) {
// Check auth
if (!ctx.getQualifiedUser().equals(user) && !Env.getCurrentEnv().getAccessManager()
.checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) {
continue;
}
infos.add(ctx.toThreadInfo(isFull));
}
return infos;
}
// used for thrift
public List<List<String>> listConnectionForRpc(UserIdentity userIdentity, boolean isShowFullSql,
boolean isShowFeHost) {
List<List<String>> list = new ArrayList<>();
long nowMs = System.currentTimeMillis();
for (ConnectContext ctx : connectionMap.values()) {
// Check auth
if (!ctx.getCurrentUserIdentity().equals(userIdentity) && !Env.getCurrentEnv().getAccessManager()
.checkGlobalPriv(userIdentity, PrivPredicate.GRANT)) {
continue;
}
list.add(ctx.toThreadInfo(isShowFullSql).toRow(-1, nowMs, isShowFeHost));
}
return list;
}
public void putTraceId2QueryId(String traceId, TUniqueId queryId) {
traceId2QueryId.put(traceId, queryId);
}
public String getQueryIdByTraceId(String traceId) {
TUniqueId queryId = traceId2QueryId.get(traceId);
return queryId == null ? "" : DebugUtil.printId(queryId);
}
public Map<Integer, ConnectContext> getConnectionMap() {
return connectionMap;
}
public Map<String, AtomicInteger> getUserConnectionMap() {
return connByUser;
}
public int getMaxConnections() {
return maxConnections;
}
}

View File

@ -18,13 +18,12 @@
package org.apache.doris.qe;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Config;
import org.apache.doris.common.ThreadPoolManager;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext.ConnectType;
import org.apache.doris.thrift.TUniqueId;
import org.apache.doris.qe.ConnectContext.ThreadInfo;
import org.apache.doris.service.arrowflight.sessions.FlightSqlConnectPoolMgr;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
@ -43,15 +42,9 @@ import java.util.concurrent.atomic.AtomicInteger;
// TODO(zhaochun): We should consider if the number of local file connection can >= maximum connections later.
public class ConnectScheduler {
private static final Logger LOG = LogManager.getLogger(ConnectScheduler.class);
private final int maxConnections;
private final AtomicInteger numberConnection;
private final AtomicInteger nextConnectionId;
private final Map<Integer, ConnectContext> connectionMap = Maps.newConcurrentMap();
private final Map<String, AtomicInteger> connByUser = Maps.newConcurrentMap();
private final Map<String, Integer> flightToken2ConnectionId = Maps.newConcurrentMap();
// valid trace id -> query id
private final Map<String, TUniqueId> traceId2QueryId = Maps.newConcurrentMap();
private final ConnectPoolMgr connectPoolMgr;
private final FlightSqlConnectPoolMgr flightSqlConnectPoolMgr;
// Use a thread to check whether connection is timeout. Because
// 1. If use a scheduler, the task maybe a huge number when query is messy.
@ -60,24 +53,26 @@ public class ConnectScheduler {
private final ScheduledExecutorService checkTimer = ThreadPoolManager.newDaemonScheduledThreadPool(1,
"connect-scheduler-check-timer", true);
public ConnectScheduler(int maxConnections) {
this.maxConnections = maxConnections;
numberConnection = new AtomicInteger(0);
public ConnectScheduler(int commonMaxConnections, int flightSqlMaxConnections) {
nextConnectionId = new AtomicInteger(0);
this.connectPoolMgr = new ConnectPoolMgr(commonMaxConnections);
this.flightSqlConnectPoolMgr = new FlightSqlConnectPoolMgr(flightSqlMaxConnections);
checkTimer.scheduleAtFixedRate(new TimeoutChecker(), 0, 1000L, TimeUnit.MILLISECONDS);
}
private class TimeoutChecker extends TimerTask {
@Override
public void run() {
long now = System.currentTimeMillis();
for (ConnectContext connectContext : connectionMap.values()) {
connectContext.checkTimeout(now);
}
}
public ConnectScheduler(int commonMaxConnections) {
this(commonMaxConnections, Config.arrow_flight_max_connections);
}
// submit one MysqlContext or ArrowFlightSqlContext to this scheduler.
public ConnectPoolMgr getConnectPoolMgr() {
return connectPoolMgr;
}
public FlightSqlConnectPoolMgr getFlightSqlConnectPoolMgr() {
return flightSqlConnectPoolMgr;
}
// submit one MysqlContext to this scheduler.
// return true, if this connection has been successfully submitted, otherwise return false.
// Caller should close ConnectContext if return false.
public boolean submit(ConnectContext context) {
@ -89,89 +84,38 @@ public class ConnectScheduler {
return true;
}
// Register one connection with its connection id.
// Return -1 means register OK
// Return >=0 means register failed, and return value is current connection num.
public int registerConnection(ConnectContext ctx) {
if (numberConnection.incrementAndGet() > maxConnections) {
numberConnection.decrementAndGet();
return numberConnection.get();
}
// Check user
connByUser.putIfAbsent(ctx.getQualifiedUser(), new AtomicInteger(0));
AtomicInteger conns = connByUser.get(ctx.getQualifiedUser());
if (conns.incrementAndGet() > ctx.getEnv().getAuth().getMaxConn(ctx.getQualifiedUser())) {
conns.decrementAndGet();
numberConnection.decrementAndGet();
return numberConnection.get();
}
connectionMap.put(ctx.getConnectionId(), ctx);
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId());
}
return -1;
}
public void unregisterConnection(ConnectContext ctx) {
ctx.closeTxn();
if (connectionMap.remove(ctx.getConnectionId()) != null) {
AtomicInteger conns = connByUser.get(ctx.getQualifiedUser());
if (conns != null) {
conns.decrementAndGet();
}
numberConnection.decrementAndGet();
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
flightToken2ConnectionId.remove(ctx.getPeerIdentity());
}
}
}
public ConnectContext getContext(int connectionId) {
return connectionMap.get(connectionId);
ConnectContext ctx = connectPoolMgr.getContext(connectionId);
if (ctx == null) {
ctx = flightSqlConnectPoolMgr.getContext(connectionId);
}
return ctx;
}
public ConnectContext getContextWithQueryId(String queryId) {
for (ConnectContext context : connectionMap.values()) {
if (queryId.equals(DebugUtil.printId(context.queryId))) {
return context;
}
ConnectContext ctx = connectPoolMgr.getContextWithQueryId(queryId);
if (ctx == null) {
ctx = flightSqlConnectPoolMgr.getContextWithQueryId(queryId);
}
return null;
return ctx;
}
public ConnectContext getContext(String flightToken) {
if (flightToken2ConnectionId.containsKey(flightToken)) {
int connectionId = flightToken2ConnectionId.get(flightToken);
return getContext(connectionId);
}
return null;
}
public void cancelQuery(String queryId, String cancelReason) {
for (ConnectContext ctx : connectionMap.values()) {
TUniqueId qid = ctx.queryId();
if (qid != null && DebugUtil.printId(qid).equals(queryId)) {
ctx.cancelQuery(cancelReason);
break;
}
public boolean cancelQuery(String queryId, String cancelReason) {
boolean ret = connectPoolMgr.cancelQuery(queryId, cancelReason);
if (!ret) {
ret = flightSqlConnectPoolMgr.cancelQuery(queryId, cancelReason);
}
return ret;
}
public int getConnectionNum() {
return numberConnection.get();
return connectPoolMgr.getConnectionNum() + flightSqlConnectPoolMgr.getConnectionNum();
}
public List<ConnectContext.ThreadInfo> listConnection(String user, boolean isFull) {
public List<ThreadInfo> listConnection(String user, boolean isFull) {
List<ConnectContext.ThreadInfo> infos = Lists.newArrayList();
for (ConnectContext ctx : connectionMap.values()) {
// Check auth
if (!ctx.getQualifiedUser().equals(user) && !Env.getCurrentEnv().getAccessManager()
.checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) {
continue;
}
infos.add(ctx.toThreadInfo(isFull));
}
infos.addAll(connectPoolMgr.listConnection(user, isFull));
infos.addAll(flightSqlConnectPoolMgr.listConnection(user, isFull));
return infos;
}
@ -179,33 +123,39 @@ public class ConnectScheduler {
public List<List<String>> listConnectionForRpc(UserIdentity userIdentity, boolean isShowFullSql,
boolean isShowFeHost) {
List<List<String>> list = new ArrayList<>();
long nowMs = System.currentTimeMillis();
for (ConnectContext ctx : connectionMap.values()) {
// Check auth
if (!ctx.getCurrentUserIdentity().equals(userIdentity) && !Env.getCurrentEnv()
.getAccessManager()
.checkGlobalPriv(userIdentity, PrivPredicate.GRANT)) {
continue;
}
list.add(ctx.toThreadInfo(isShowFullSql).toRow(-1, nowMs, isShowFeHost));
}
list.addAll(connectPoolMgr.listConnectionForRpc(userIdentity, isShowFullSql, isShowFeHost));
list.addAll(flightSqlConnectPoolMgr.listConnectionForRpc(userIdentity, isShowFullSql, isShowFeHost));
return list;
}
public void putTraceId2QueryId(String traceId, TUniqueId queryId) {
traceId2QueryId.put(traceId, queryId);
}
public String getQueryIdByTraceId(String traceId) {
TUniqueId queryId = traceId2QueryId.get(traceId);
return queryId == null ? "" : DebugUtil.printId(queryId);
String queryId = connectPoolMgr.getQueryIdByTraceId(traceId);
if (Strings.isNullOrEmpty(queryId)) {
queryId = flightSqlConnectPoolMgr.getQueryIdByTraceId(traceId);
}
return queryId;
}
public Map<Integer, ConnectContext> getConnectionMap() {
return connectionMap;
Map<Integer, ConnectContext> map = Maps.newConcurrentMap();
map.putAll(connectPoolMgr.getConnectionMap());
map.putAll(flightSqlConnectPoolMgr.getConnectionMap());
return map;
}
public int getMaxConnections() {
return maxConnections;
public Map<String, AtomicInteger> getUserConnectionMap() {
Map<String, AtomicInteger> map = Maps.newConcurrentMap();
map.putAll(connectPoolMgr.getUserConnectionMap());
map.putAll(flightSqlConnectPoolMgr.getUserConnectionMap());
return map;
}
private class TimeoutChecker extends TimerTask {
@Override
public void run() {
long now = System.currentTimeMillis();
connectPoolMgr.timeoutChecker(now);
flightSqlConnectPoolMgr.timeoutChecker(now);
}
}
}

View File

@ -39,7 +39,7 @@ public class ExecuteEnv {
private ExecuteEnv() {
multiLoadMgr = new MultiLoadMgr();
scheduler = new ConnectScheduler(Config.qe_max_connection);
scheduler = new ConnectScheduler(Config.qe_max_connection, Config.arrow_flight_max_connections);
startupTime = System.currentTimeMillis();
processUUID = System.currentTimeMillis();
String logDir = Strings.isNullOrEmpty(Config.sys_log_dir) ? System.getenv("LOG_DIR") :

View File

@ -36,6 +36,8 @@ import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.CloseSessionResult;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
@ -110,7 +112,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
this.flightSessionsManager = flightSessionsManager;
sqlInfoBuilder = new SqlInfoBuilder();
sqlInfoBuilder.withFlightSqlServerName("DorisFE").withFlightSqlServerVersion("1.0")
.withFlightSqlServerArrowVersion("13.0").withFlightSqlServerReadOnly(false)
.withFlightSqlServerArrowVersion("18.2.0").withFlightSqlServerReadOnly(false)
.withSqlIdentifierQuoteChar("`").withSqlDdlCatalog(true).withSqlDdlSchema(false).withSqlDdlTable(false)
.withSqlIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE)
.withSqlQuotedIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE);
@ -139,7 +141,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot();
listener.start(vectorSchemaRoot);
listener.putNext();
} catch (Exception e) {
} catch (Throwable e) {
String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
@ -172,7 +174,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
String executedPeerIdentity = handleParts[0];
String preparedStatementId = handleParts[1];
flightSessionsManager.getConnectContext(executedPeerIdentity).removePreparedQuery(preparedStatementId);
} catch (final Exception e) {
} catch (final Throwable e) {
listener.onError(e);
return;
}
@ -274,7 +276,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
return new FlightInfo(flightSQLConnectProcessor.getArrowSchema(), descriptor, endpoints, -1, -1);
}
}
} catch (Exception e) {
} catch (Throwable e) {
String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
+ ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
+ connectContext.getState().getErrorMessage();
@ -288,8 +290,14 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
@Override
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
final FlightDescriptor descriptor) {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
return executeQueryStatement(context.peerIdentity(), connectContext, request.getQuery(), descriptor);
try {
ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
return executeQueryStatement(context.peerIdentity(), connectContext, request.getQuery(), descriptor);
} catch (Throwable e) {
String errMsg = "get flight info statement failed, " + e.getMessage();
LOG.error(errMsg, e);
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
}
}
@Override
@ -402,7 +410,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
}
}
ackStream.onCompleted();
} catch (Exception e) {
} catch (Throwable e) {
String errMsg = "acceptPutPreparedStatementUpdate failed, " + e.getMessage() + ", "
+ Util.getRootCauseMessage(e);
LOG.error(errMsg, e);
@ -461,7 +469,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
} catch (final Throwable e) {
handleStreamException(e, "", listener);
}
}
@ -488,7 +496,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
} catch (final Throwable e) {
handleStreamException(e, "", listener);
}
}
@ -520,7 +528,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
listener.putNext();
listener.completed();
}
} catch (final Exception e) {
} catch (final Throwable e) {
handleStreamException(e, "", listener);
}
}
@ -584,6 +592,25 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCrossReference unimplemented").toRuntimeException();
}
@Override
public void closeSession(CloseSessionRequest request, final CallContext context,
final StreamListener<CloseSessionResult> listener) {
// https://github.com/apache/arrow-adbc/issues/2821
// currently FlightSqlConnection does not provide a separate interface for external calls to
// FlightSqlClient::closeSession(), nor will it automatically call closeSession
// when FlightSqlConnection::close(). Python flight sql Cursor.close() will call closeSession().
// Neither C++ nor Java seem to have similar behavior.
try {
flightSessionsManager.closeConnectContext(context.peerIdentity());
} catch (final Throwable e) {
LOG.error("closeSession failed", e);
listener.onError(
CallStatus.INTERNAL.withDescription("closeSession failed").withCause(e).toRuntimeException());
}
listener.onNext(new CloseSessionResult(CloseSessionResult.Status.CLOSED));
listener.onCompleted();
}
private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
final Schema schema) {
final Ticket ticket = new Ticket(Any.pack(request).toByteArray());
@ -592,7 +619,7 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
return new FlightInfo(schema, descriptor, endpoints, -1, -1);
}
private static void handleStreamException(Exception e, String errMsg, ServerStreamListener listener) {
private static void handleStreamException(Throwable e, String errMsg, ServerStreamListener listener) {
LOG.error(errMsg, e);
listener.error(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();

View File

@ -46,22 +46,20 @@ public class DorisFlightSqlService {
public DorisFlightSqlService(int port) {
BufferAllocator allocator = new RootAllocator();
// arrow_flight_token_cache_size less than qe_max_connection to avoid `Reach limit of connections`.
// arrow flight sql is a stateless protocol, connection is usually not actively disconnected.
// bearer token is evict from the cache will unregister ConnectContext.
this.flightTokenManager = new FlightTokenManagerImpl(
Math.min(Config.arrow_flight_token_cache_size, Config.qe_max_connection / 2),
Config.arrow_flight_token_alive_time);
Math.min(Config.arrow_flight_max_connections, Config.arrow_flight_token_cache_size),
Config.arrow_flight_token_alive_time_second);
this.flightSessionsManager = new FlightSessionsWithTokenManager(flightTokenManager);
DorisFlightSqlProducer producer = new DorisFlightSqlProducer(
Location.forGrpcInsecure(FrontendOptions.getLocalHostAddress(), port), flightSessionsManager);
flightServer = FlightServer.builder(allocator, Location.forGrpcInsecure("0.0.0.0", port), producer)
.headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build();
LOG.info("Arrow Flight SQL service is created, port: {}, token_cache_size: {}"
+ ", qe_max_connection: {}, token_alive_time: {}",
port, Config.arrow_flight_token_cache_size, Config.qe_max_connection,
Config.arrow_flight_token_alive_time);
LOG.info("Arrow Flight SQL service is created, port: {}, arrow_flight_max_connections: {}"
+ "arrow_flight_token_alive_time_second: {}", port, Config.arrow_flight_max_connections,
Config.arrow_flight_token_alive_time_second);
}
// start Arrow Flight SQL service, return true if success, otherwise false

View File

@ -45,6 +45,13 @@ public interface FlightSessionsManager {
*/
ConnectContext createConnectContext(String peerIdentity);
/**
* Close ConnectContext object and delete it in the local cache.
*
* @param peerIdentity identity after authorization
*/
void closeConnectContext(String peerIdentity);
static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) {
ConnectContext connectContext = new FlightSqlConnectContext(peerIdentity);
connectContext.setEnv(Env.getCurrentEnv());

View File

@ -41,7 +41,8 @@ public class FlightSessionsWithTokenManager implements FlightSessionsManager {
@Override
public ConnectContext getConnectContext(String peerIdentity) {
try {
ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getContext(peerIdentity);
ConnectContext connectContext = ExecuteEnv.getInstance().getScheduler().getFlightSqlConnectPoolMgr()
.getContextWithFlightToken(peerIdentity);
if (null == connectContext) {
connectContext = createConnectContext(peerIdentity);
return connectContext;
@ -68,18 +69,21 @@ public class FlightSessionsWithTokenManager implements FlightSessionsManager {
flightTokenDetails.getUserIdentity(), flightTokenDetails.getRemoteIp());
ConnectScheduler connectScheduler = ExecuteEnv.getInstance().getScheduler();
connectScheduler.submit(connectContext);
int res = connectScheduler.registerConnection(connectContext);
int res = connectScheduler.getFlightSqlConnectPoolMgr().registerConnection(connectContext);
if (res >= 0) {
long userConnLimit = connectContext.getEnv().getAuth().getMaxConn(connectContext.getQualifiedUser());
String errMsg = String.format(
"Reach limit of connections. Total: %d, User: %d, Current: %d. "
+ "Increase `qe_max_connection` in fe.conf or user's `max_user_connections`,"
+ " or decrease `arrow_flight_token_cache_size` "
+ "to evict unused bearer tokens and it connections faster",
connectScheduler.getMaxConnections(), userConnLimit, res);
connectContext.getState().setError(ErrorCode.ERR_TOO_MANY_USER_CONNECTIONS, errMsg);
"Register arrow flight sql connection failed, Unknown Error, the number of arrow flight "
+ "bearer tokens should be equal to arrow flight sql max connections, "
+ "max connections: %d, used: %d.",
connectScheduler.getFlightSqlConnectPoolMgr().getMaxConnections(), res);
connectContext.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR, errMsg);
throw new IllegalArgumentException(errMsg);
}
return connectContext;
}
@Override
public void closeConnectContext(String peerIdentity) {
flightTokenManager.invalidateToken(peerIdentity);
}
}

View File

@ -22,7 +22,9 @@ import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.ConnectProcessor;
import org.apache.doris.service.arrowflight.results.FlightSqlChannel;
import org.apache.doris.thrift.TResultSinkType;
import org.apache.doris.thrift.TUniqueId;
import com.google.common.base.Strings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -61,7 +63,7 @@ public class FlightSqlConnectContext extends ConnectContext {
if (flightSqlChannel != null) {
flightSqlChannel.close();
}
connectScheduler.unregisterConnection(this);
connectScheduler.getFlightSqlConnectPoolMgr().unregisterConnection(this);
}
// kill operation with no protect.
@ -78,6 +80,17 @@ public class FlightSqlConnectContext extends ConnectContext {
cancelQuery("arrow flight query killed by user");
}
@Override
public void setQueryId(TUniqueId queryId) {
if (this.queryId != null) {
this.lastQueryId = this.queryId.deepCopy();
}
this.queryId = queryId;
if (connectScheduler != null && !Strings.isNullOrEmpty(traceId)) {
connectScheduler.getFlightSqlConnectPoolMgr().putTraceId2QueryId(traceId, queryId);
}
}
@Override
public String getRemoteHostPortString() {
return getFlightSqlChannel().getRemoteHostPortString();

View File

@ -0,0 +1,74 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.service.arrowflight.sessions;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.ConnectContext.ConnectType;
import org.apache.doris.qe.ConnectPoolMgr;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.Map;
public class FlightSqlConnectPoolMgr extends ConnectPoolMgr {
private static final Logger LOG = LogManager.getLogger(
FlightSqlConnectPoolMgr.class);
private final Map<String, Integer> flightToken2ConnectionId = Maps.newConcurrentMap();
public FlightSqlConnectPoolMgr(int maxConnections) {
super(maxConnections);
}
// Register one connection with its connection id.
// Return -1 means register OK
// Return >=0 means register failed, and return value is current connection num.
@Override
public int registerConnection(ConnectContext ctx) {
if (numberConnection.incrementAndGet() > maxConnections) {
numberConnection.decrementAndGet();
return numberConnection.get();
}
// not check user
connectionMap.put(ctx.getConnectionId(), ctx);
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId());
}
return -1;
}
@Override
public void unregisterConnection(ConnectContext ctx) {
ctx.closeTxn();
if (connectionMap.remove(ctx.getConnectionId()) != null) {
numberConnection.decrementAndGet();
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
flightToken2ConnectionId.remove(ctx.getPeerIdentity());
}
}
}
public ConnectContext getContextWithFlightToken(String flightToken) {
if (flightToken2ConnectionId.containsKey(flightToken)) {
int connectionId = flightToken2ConnectionId.get(flightToken);
return getContext(connectionId);
}
return null;
}
}

View File

@ -59,20 +59,26 @@ public class FlightTokenManagerImpl implements FlightTokenManager {
private ScheduledExecutorService cleanupExecutor;
public FlightTokenManagerImpl(final int cacheSize, final int cacheExpiration) {
// The cache size of all user tokens in Arrow Flight Server. which will be eliminated by
// LRU rules after exceeding the limit, the default value is arrow_flight_max_connections,
// arrow flight sql is a stateless protocol, the connection is usually not actively
// disconnected, bearer token is evict from the cache will unregister ConnectContext.
this.cacheSize = cacheSize;
this.cacheExpiration = cacheExpiration;
this.tokenCache = CacheBuilder.newBuilder().maximumSize(cacheSize)
.expireAfterWrite(cacheExpiration, TimeUnit.MINUTES)
.expireAfterWrite(cacheExpiration, TimeUnit.SECONDS)
.removalListener(new RemovalListener<String, FlightTokenDetails>() {
@Override
public void onRemoval(@NotNull RemovalNotification<String, FlightTokenDetails> notification) {
// TODO: broadcast this message to other FE
String token = notification.getKey();
FlightTokenDetails tokenDetails = notification.getValue();
ConnectContext context = ExecuteEnv.getInstance().getScheduler().getContext(token);
ConnectContext context = ExecuteEnv.getInstance().getScheduler().getFlightSqlConnectPoolMgr()
.getContextWithFlightToken(token);
if (context != null) {
ExecuteEnv.getInstance().getScheduler().unregisterConnection(context);
ExecuteEnv.getInstance().getScheduler().getFlightSqlConnectPoolMgr()
.unregisterConnection(context);
LOG.info("evict bearer token: " + token + " from tokenCache, reason: "
+ notification.getCause()
+ ", and unregister flight connection context after evict bearer token");
@ -145,13 +151,13 @@ public class FlightTokenManagerImpl implements FlightTokenManager {
if (value.getToken().equals("")) {
throw new IllegalArgumentException("invalid bearer token: " + token
+ ", try reconnect, bearer token may not be created, or may have been evict, search for this "
+ "token in fe.log to see the evict reason. currently in fe.conf, `arrow_flight_token_cache_size`="
+ this.cacheSize + ", `arrow_flight_token_alive_time`=" + this.cacheExpiration);
+ "token in fe.log to see the evict reason. currently in fe.conf, `arrow_flight_max_connections`="
+ this.cacheSize + ", `arrow_flight_token_alive_time_second`=" + this.cacheExpiration);
}
if (System.currentTimeMillis() >= value.getExpiresAt()) {
tokenCache.invalidate(token);
throw new IllegalArgumentException("bearer token expired: " + token + ", try reconnect, "
+ "currently in fe.conf, `arrow_flight_token_alive_time`=" + this.cacheExpiration);
+ "currently in fe.conf, `arrow_flight_token_alive_time_second`=" + this.cacheExpiration);
}
if (usersTokenLRU.containsKey(value.getUsername())) {
try {

View File

@ -48,8 +48,8 @@ under the License.
similar issue: https://github.com/protocolbuffers/protobuf/issues/15762
3. A more stable version is Arrow 15.0.2 and ADBC 0.12.0, but we always hope to embrace the future with new versions!
-->
<arrow.version>18.1.0</arrow.version>
<adbc.version>0.15.0</adbc.version>
<arrow.version>18.2.0</arrow.version>
<adbc.version>0.18.0</adbc.version>
<log4j.version>2.17.1</log4j.version>
</properties>
<dependencies>

8
thirdparty/vars.sh vendored
View File

@ -253,10 +253,10 @@ GRPC_SOURCE=grpc-1.54.3
GRPC_MD5SUM="af00a2edeae0f02bb25917cc3473b7de"
# arrow
ARROW_DOWNLOAD="https://github.com/apache/arrow/archive/refs/tags/apache-arrow-17.0.0.tar.gz"
ARROW_NAME="apache-arrow-17.0.0.tar.gz"
ARROW_SOURCE="arrow-apache-arrow-17.0.0"
ARROW_MD5SUM="ba18bf83e2164abd34b9ac4cb164f0f0"
ARROW_DOWNLOAD="https://github.com/apache/arrow/archive/refs/tags/apache-arrow-19.0.1.tar.gz"
ARROW_NAME="apache-arrow-19.0.1.tar.gz"
ARROW_SOURCE="arrow-apache-arrow-19.0.1"
ARROW_MD5SUM="8c5091da0f8fb41a47d7f4dad7b712df"
# Abseil
ABSEIL_DOWNLOAD="https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.3.tar.gz"