[UDF] Fix bug that UDF can't handle constant null value (#2914)

This CL modifies `evalExpr()` in ExpressionFunctions so that it no longer replaces a
`FunctionCallExpr` with a `NullLiteral` when a UDF is called with a null parameter. This fixes the
problem described in ISSUE #2913.
This commit is contained in:
Mingyu Chen
2020-02-17 22:13:50 +08:00
committed by GitHub
parent 1089f09d26
commit 0fb52c514b
11 changed files with 202 additions and 43 deletions

View File

@ -17,9 +17,6 @@
package org.apache.doris.analysis;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSortedMap;
import org.apache.commons.codec.binary.Hex;
import org.apache.doris.catalog.AggregateFunction;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.catalog.Function;
@ -31,6 +28,11 @@ import org.apache.doris.common.UserException;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSortedMap;
import org.apache.commons.codec.binary.Hex;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;

View File

@ -17,18 +17,20 @@
package org.apache.doris.analysis;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.rewrite.FEFunction;
import org.apache.doris.rewrite.FEFunctions;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -46,12 +48,6 @@ public enum ExpressionFunctions {
private static final Logger LOG = LogManager.getLogger(ExpressionFunctions.class);
private ImmutableMultimap<String, FEFunctionInvoker> functions;
// For most build-in functions, it will return NullLiteral when params contain NullLiteral.
// But a few functions need to handle NullLiteral differently, such as "if". It need to add
// an attribute to LiteralExpr to mark null and check the attribute to decide whether to
// replace the result with NullLiteral when function finished. It leaves to be realized.
// TODO chenhao16.
private ImmutableSet<String> nonNullResultWithNullParamFunctions;
private ExpressionFunctions() {
registerFunctions();
@ -71,8 +67,13 @@ public enum ExpressionFunctions {
Function fn = constExpr.getFn();
Preconditions.checkNotNull(fn, "Expr's fn can't be null.");
// null
if (!nonNullResultWithNullParamFunctions.contains(fn.getFunctionName().getFunction())) {
// return NullLiteral directly iff:
// 1. Not UDF
// 2. Not in NonNullResultWithNullParamFunctions
// 3. Has null parameter
if (!Catalog.getCurrentCatalog().isNonNullResultWithNullParamFunction(fn.getFunctionName().getFunction())
&& !fn.isUdf()) {
for (Expr e : constExpr.getChildren()) {
if (e instanceof NullLiteral) {
return new NullLiteral();
@ -144,15 +145,6 @@ public enum ExpressionFunctions {
}
}
this.functions = mapBuilder.build();
// Functions that need to handle null.
ImmutableSet.Builder<String> setBuilder =
new ImmutableSet.Builder<String>();
setBuilder.add("if");
setBuilder.add("hll_hash");
setBuilder.add("concat_ws");
setBuilder.add("ifnull");
this.nonNullResultWithNullParamFunctions = setBuilder.build();
}
public static class FEFunctionInvoker {

View File

@ -5317,6 +5317,10 @@ public class Catalog {
return functionSet.getBulitinFunctions();
}
/**
 * Asks this catalog's function set whether the named function is registered as one
 * that handles null parameters itself.
 *
 * @param funcName name of the function to look up
 * @return whatever {@code functionSet.isNonNullResultWithNullParamFunctions(funcName)} reports
 */
public boolean isNonNullResultWithNullParamFunction(String funcName) {
    final boolean handlesNullParams = functionSet.isNonNullResultWithNullParamFunctions(funcName);
    return handlesNullParams;
}
/**
* create cluster
*

View File

@ -19,15 +19,17 @@ package org.apache.doris.catalog;
import static org.apache.doris.common.io.IOUtils.writeOptionString;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.doris.analysis.FunctionName;
import org.apache.doris.analysis.HdfsURI;
import org.apache.doris.common.io.Text;
import org.apache.doris.common.io.Writable;
import org.apache.doris.thrift.TFunction;
import org.apache.doris.thrift.TFunctionBinaryType;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -208,6 +210,13 @@ public class Function implements Writable {
/**
 * Stores the given checksum on this function.
 *
 * @param checksum the checksum value to remember
 */
public void setChecksum(String checksum) {
    this.checksum = checksum;
}
/**
 * Returns the checksum currently stored on this function.
 *
 * @return the value of the {@code checksum} field
 */
public String getChecksum() {
    return this.checksum;
}
// TODO(cmy): Currently we judge whether this is a UDF by whether 'location' is set.
// Maybe we should use a separate variable to identify it,
// but an additional variable would require modifying the persistence information.
/**
 * Reports whether this function is considered a user defined function (UDF).
 *
 * @return {@code true} if and only if {@code location} is non-null
 */
public boolean isUdf() {
    if (location == null) {
        return false;
    }
    return true;
}
// Returns a string with the signature in human readable format:
// FnName(argtype1, argtyp2). e.g. Add(int, int)
public String signatureString() {

View File

@ -17,9 +17,6 @@
package org.apache.doris.catalog;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.CastExpr;
@ -27,14 +24,20 @@ import org.apache.doris.analysis.InPredicate;
import org.apache.doris.analysis.IsNullPredicate;
import org.apache.doris.analysis.LikePredicate;
import org.apache.doris.builtins.ScalarBuiltins;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class FunctionSet {
private static final Logger LOG = LogManager.getLogger(FunctionSet.class);
@ -46,6 +49,16 @@ public class FunctionSet {
// FunctionResolutionOrder.
private final HashMap<String, List<Function>> functions;
// Most built-in functions return NullLiteral when any of their params is a NullLiteral.
// But a few functions, such as "if", need to handle NullLiteral differently. Doing this properly
// requires adding an attribute to LiteralExpr to mark null, and checking that attribute to decide
// whether to replace the result with NullLiteral when the function finishes.
// That part has not been implemented yet.
// Functions in this set are defined in `gensrc/script/doris_builtins_functions.py`,
// and the set is built automatically.
// cmy: This does not contain any user defined functions. All UDFs handle null values by themselves.
private ImmutableSet<String> nonNullResultWithNullParamFunctions;
/**
 * Creates a function set whose name-to-functions map starts out empty.
 * The set of null-handling function names is not assigned here; it is assigned by
 * {@code buildNonNullResultWithNullParamFunction}.
 */
public FunctionSet() {
    this.functions = new HashMap<>();
}
@ -63,6 +76,18 @@ public class FunctionSet {
InPredicate.initBuiltins(this);
}
/**
 * Replaces the set of function names that handle null parameters themselves with an
 * immutable snapshot of {@code funcNames}.
 *
 * @param funcNames names to record; later changes to this set do not affect the snapshot
 */
public void buildNonNullResultWithNullParamFunction(Set<String> funcNames) {
    // ImmutableSet.copyOf takes the same immutable, null-rejecting snapshot
    // that adding each element to an ImmutableSet.Builder would produce.
    this.nonNullResultWithNullParamFunctions = ImmutableSet.copyOf(funcNames);
}
/**
 * Checks whether {@code funcName} is in the set of functions that handle null
 * parameters themselves.
 *
 * NOTE: the set is only assigned by {@code buildNonNullResultWithNullParamFunction},
 * so that method must have been called before this one.
 *
 * @param funcName function name to look up
 * @return {@code true} if the name is contained in the set
 */
public boolean isNonNullResultWithNullParamFunctions(String funcName) {
    final ImmutableSet<String> nullAwareNames = this.nonNullResultWithNullParamFunctions;
    return nullAwareNames.contains(funcName);
}
private static final Map<Type, String> MIN_UPDATE_SYMBOL =
ImmutableMap.<Type, String>builder()
.put(Type.BOOLEAN,
@ -746,8 +771,7 @@ public class FunctionSet {
return null;
}
// Only used
public boolean addFunction(Function fn) {
private boolean addFunction(Function fn, boolean isBuiltin) {
// TODO: add this to persistent store
if (getFunction(fn, Function.CompareMode.IS_INDISTINGUISHABLE) != null) {
return false;
@ -791,7 +815,7 @@ public class FunctionSet {
* Adds a builtin to this database. The function must not already exist.
*/
public void addBuiltin(Function fn) {
addFunction(fn);
addFunction(fn, true);
}
// Populate all the aggregate builtins in the catalog.

View File

@ -547,7 +547,7 @@ public class StmtExecutor {
coord.exec();
// if python's MysqlDb get error after sendfields, it can't catch the excpetion
// if python's MysqlDb get error after sendfields, it can't catch the exception
// so We need to send fields after first batch arrived
// send result

View File

@ -40,8 +40,6 @@ public class BatchRollupJobTest {
private static String runningDir = "fe/mocked/BatchRollupJobTest/" + UUID.randomUUID().toString() + "/";
private static ConnectContext ctx = UtFrameUtils.createDefaultCtx();
@BeforeClass
public static void setup() throws Exception {
UtFrameUtils.createMinDorisCluster(runningDir);
@ -49,7 +47,7 @@ public class BatchRollupJobTest {
@Test
public void test() throws Exception {
System.out.println("xxx");
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
// create database db1
String createDbStmtStr = "create database db1;";
CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);

View File

@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateFunctionStmt;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.Planner;
import org.apache.doris.planner.UnionNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.QueryState;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.util.List;
import java.util.UUID;
/*
* Author: Chenmingyu
* Date: Feb 16, 2020
*/
/**
 * Creates a UDF in a minimal mocked Doris cluster and checks that a constant call with a
 * null argument ({@code select db1.my_add(null)}) stays a {@link FunctionCallExpr} in the plan.
 */
public class CreateFunctionTest {
    // Parent directory of all working directories created by this test class.
    private static final String BASE_DIR = "fe/mocked/CreateFunctionTest/";

    // Each run gets its own random sub-directory under BASE_DIR.
    private static String runningDir = BASE_DIR + UUID.randomUUID().toString() + "/";

    /** Starts the minimal mocked cluster inside {@code runningDir}. */
    @BeforeClass
    public static void setup() throws Exception {
        UtFrameUtils.createMinDorisCluster(runningDir);
    }

    /** Removes the working directory tree created for this test class. */
    @AfterClass
    public static void teardown() {
        // File.delete() only removes empty directories, and BASE_DIR contains the per-run
        // sub-directory, so the tree has to be removed bottom-up.
        deleteRecursively(new File(BASE_DIR));
    }

    /** Best-effort deletion of {@code file} and, if it is a directory, everything below it. */
    private static void deleteRecursively(File file) {
        // listFiles() returns null when 'file' is not a directory (or cannot be listed).
        File[] children = file.listFiles();
        if (children != null) {
            for (File child : children) {
                deleteRecursively(child);
            }
        }
        file.delete();
    }

    @Test
    public void test() throws Exception {
        ConnectContext ctx = UtFrameUtils.createDefaultCtx();

        // create database db1
        String createDbStmtStr = "create database db1;";
        CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
        Catalog.getCurrentCatalog().createDb(createDbStmt);
        System.out.println(Catalog.getCurrentCatalog().getDbNames());

        Database db = Catalog.getCurrentCatalog().getDb("default_cluster:db1");
        Assert.assertNotNull(db);

        // create the UDF
        // NOTE(review): object_file points at a specific HTTP host; presumably it must be
        // reachable for statement analysis to succeed -- confirm.
        String createFuncStr = "create function db1.my_add(VARCHAR(1024)) RETURNS BOOLEAN properties\n" +
                "(\n" +
                "\"symbol\" = \"_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_9StringValE\",\n" +
                "\"prepare_fn\" = \"_ZN9doris_udf13AddUdfPrepareEPNS_15FunctionContextENS0_18FunctionStateScopeE\",\n" +
                "\"close_fn\" = \"_ZN9doris_udf11AddUdfCloseEPNS_15FunctionContextENS0_18FunctionStateScopeE\",\n" +
                "\"object_file\" = \"http://nmg01-inf-dorishb00.nmg01.baidu.com:8456/libcmy_udf.so\"\n" +
                ");";

        CreateFunctionStmt createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
        Catalog.getCurrentCatalog().createFunction(createFunctionStmt);

        List<Function> functions = db.getFunctions();
        Assert.assertEquals(1, functions.size());
        Assert.assertTrue(functions.get(0).isUdf());

        // run a constant query that passes null to the UDF
        String queryStr = "select db1.my_add(null)";
        ctx.getState().reset();
        StmtExecutor stmtExecutor = new StmtExecutor(ctx, queryStr);
        stmtExecutor.execute();
        Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());

        // the plan must be a single fragment rooted at a UnionNode ...
        Planner planner = stmtExecutor.planner();
        Assert.assertEquals(1, planner.getFragments().size());
        PlanFragment fragment = planner.getFragments().get(0);
        Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
        UnionNode unionNode = (UnionNode) fragment.getPlanRoot();

        // ... whose only constant expression is still the function call, not a NullLiteral
        List<List<Expr>> constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
        Assert.assertEquals(1, constExprLists.size());
        Assert.assertEquals(1, constExprLists.get(0).size());
        Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
    }
}

View File

@ -42,6 +42,7 @@ import com.google.common.collect.Maps;
import java.io.IOException;
import java.io.StringReader;
import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.Map;
import java.util.Random;
@ -49,8 +50,9 @@ import java.util.Random;
public class UtFrameUtils {
// Help to create a mocked ConnectContext.
public static ConnectContext createDefaultCtx() {
ConnectContext ctx = new ConnectContext();
public static ConnectContext createDefaultCtx() throws IOException {
SocketChannel channel = SocketChannel.open();
ConnectContext ctx = new ConnectContext(channel);
ctx.setCluster(SystemInfoService.DEFAULT_CLUSTER);
ctx.setCurrentUserIdentity(UserIdentity.ROOT);
ctx.setQualifiedUser(PaloAuth.ROOT_USER);

View File

@ -708,5 +708,19 @@ visible_functions = [
[['grouping'], 'BIGINT', ['BIGINT'], '_ZN5doris21GroupingSetsFunctions8groupingEPN9doris_udf15FunctionContextERKNS1_9BigIntValE'],
]
# Except for the following functions, every function directly returns
# null if any of its parameters is null.
# Functions in this set handle null values themselves instead of just returning null.
#
# This set is only used when replacing 'functions with null parameters' with NullLiteral
# while applying FoldConstantsRule on the FE side.
# TODO(cmy): Are these the only functions that are required to handle null values?
# Names of the builtins that process null arguments themselves.
non_null_result_with_null_param_functions = [
    'if',
    'hll_hash',
    'concat_ws',
    'ifnull',
]
invisible_functions = [
]

View File

@ -36,6 +36,8 @@ package org.apache.doris.builtins;\n\
\n\
import org.apache.doris.catalog.PrimitiveType;\n\
import org.apache.doris.catalog.FunctionSet;\n\
import com.google.common.collect.Sets;\n\
import java.util.Set;\n\
\n\
public class ScalarBuiltins { \n\
public static void initBuiltins(FunctionSet functionSet) { \
@ -111,9 +113,16 @@ def generate_fe_registry_init(filename):
for entry in meta_data_entries:
for name in entry["sql_names"]:
java_output = generate_fe_entry(entry, name)
java_registry_file.write(" functionSet.addScalarBuiltin(%s);\n" % java_output)
java_registry_file.write(" functionSet.addScalarBuiltin(%s);\n" % java_output)
java_registry_file.write("\n")
# add non_null_result_with_null_param_functions
java_registry_file.write(" Set<String> funcNames = Sets.newHashSet();\n")
for entry in doris_builtins_functions.non_null_result_with_null_param_functions:
java_registry_file.write(" funcNames.add(\"%s\");\n" % entry)
java_registry_file.write(" functionSet.buildNonNullResultWithNullParamFunction(funcNames);\n");
java_registry_file.write(java_registry_epilogue)
java_registry_file.close()