[branch-2.1](function) Refine crypto functions signature to fix wrong result(#40285) (#40648)

pick https://github.com/apache/doris/pull/40285
This commit is contained in:
zclllhhjj
2024-09-11 15:32:19 +08:00
committed by GitHub
parent 66421c4270
commit 3246baa451
16 changed files with 29 additions and 763 deletions

View File

@ -640,11 +640,7 @@ public class FunctionCallExpr extends Expr {
&& (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) {
sb.append("\'***\'");
} else if (orderByElements.size() > 0 && i == len - orderByElements.size()) {
sb.append("ORDER BY ");
@ -718,22 +714,13 @@ public class FunctionCallExpr extends Expr {
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")) {
len = len - 1;
}
for (int i = 0; i < len; ++i) {
if (i == 1 && (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))) {
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt"))) {
result.add("\'***\'");
} else {
result.add(children.get(i).toDigest());
@ -1141,13 +1128,8 @@ public class FunctionCallExpr extends Expr {
if ((fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))
&& (children.size() == 2 || children.size() == 3)) {
String blockEncryptionMode = "";
Set<String> aesModes = new HashSet<>(Arrays.asList(
"AES_128_ECB",
"AES_192_ECB",
@ -1181,43 +1163,20 @@ public class FunctionCallExpr extends Expr {
"SM4_128_OFB",
"SM4_128_CTR"));
String blockEncryptionMode = "";
if (ConnectContext.get() != null) {
blockEncryptionMode = ConnectContext.get().getSessionVariable().getBlockEncryptionMode();
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")) {
if (StringUtils.isAllBlank(blockEncryptionMode)) {
blockEncryptionMode = "AES_128_ECB";
}
if (!aesModes.contains(blockEncryptionMode.toUpperCase())) {
throw new AnalysisException("session variable block_encryption_mode is invalid with aes");
}
if (children.size() == 2) {
boolean isECB = blockEncryptionMode.equalsIgnoreCase("AES_128_ECB")
|| blockEncryptionMode.equalsIgnoreCase("AES_192_ECB")
|| blockEncryptionMode.equalsIgnoreCase("AES_256_ECB");
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")) {
if (!isECB) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'aes_decrypt'");
}
} else if (fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
if (!isECB) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'aes_encrypt'");
}
} else {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
blockEncryptionMode = "AES_128_ECB";
}
}
}
if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")) {
if (StringUtils.isAllBlank(blockEncryptionMode)) {
blockEncryptionMode = "SM4_128_ECB";
}
@ -1225,36 +1184,12 @@ public class FunctionCallExpr extends Expr {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with sm4");
}
if (children.size() == 2) {
if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'sm4_decrypt'");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'sm4_encrypt'");
} else {
// if there are only 2 params, we need add an empty string as the third param
// and set encryption mode to SM4_128_ECB
// this keeps the behavior consistent with old doris ver.
children.add(new StringLiteral(""));
blockEncryptionMode = "SM4_128_ECB";
}
}
}
} else {
throw new AnalysisException("cannot get session variable `block_encryption_mode`, "
+ "please explicitly specify by using 4-args function");
}
if (!blockEncryptionMode.equals(children.get(children.size() - 1).toString())) {
children.add(new StringLiteral(blockEncryptionMode));
}
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")) {
fnName = FunctionName.createBuiltinName("aes_decrypt");
} else if (fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
fnName = FunctionName.createBuiltinName("aes_encrypt");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")) {
fnName = FunctionName.createBuiltinName("sm4_decrypt");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
fnName = FunctionName.createBuiltinName("sm4_encrypt");
}
children.add(new StringLiteral(blockEncryptionMode));
}
}

View File

@ -22,9 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Regexp;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Acos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AppendTrailingCharIfAbsent;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayApply;
@ -358,9 +356,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Sleep;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Decrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4DecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Encrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4EncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Space;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByChar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByString;
@ -465,9 +461,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Abs.class, "abs"),
scalar(Acos.class, "acos"),
scalar(AesDecrypt.class, "aes_decrypt"),
scalar(AesDecryptV2.class, "aes_decrypt_v2"),
scalar(AesEncrypt.class, "aes_encrypt"),
scalar(AesEncryptV2.class, "aes_encrypt_v2"),
scalar(AppendTrailingCharIfAbsent.class, "append_trailing_char_if_absent"),
scalar(Array.class, "array"),
scalar(ArrayApply.class, "array_apply"),
@ -823,9 +817,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Sm3.class, "sm3"),
scalar(Sm3sum.class, "sm3sum"),
scalar(Sm4Decrypt.class, "sm4_decrypt"),
scalar(Sm4DecryptV2.class, "sm4_decrypt_v2"),
scalar(Sm4Encrypt.class, "sm4_encrypt"),
scalar(Sm4EncryptV2.class, "sm4_encrypt_v2"),
scalar(Space.class, "space"),
scalar(SplitByChar.class, "split_by_char"),
scalar(SplitByString.class, "split_by_string"),

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -58,16 +57,7 @@ public class AesDecrypt extends AesCryptoFunction {
* AesDecrypt
*/
public AesDecrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
super("aes_decrypt", arg0, arg1, new StringLiteral("AES_128_ECB"));
// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("AES_128_ECB");
if (!AES_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with aes");
}
super("aes_decrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}
public AesDecrypt(Expression arg0, Expression arg1, Expression arg2) {
@ -89,7 +79,7 @@ public class AesDecrypt extends AesCryptoFunction {
} else if (children().size() == 3) {
return new AesDecrypt(children.get(0), children.get(1), children.get(2));
} else {
return new AesDecrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

View File

@ -1,74 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* ScalarFunction 'aes_decrypt'. This class is generated by GenerateFunction.
*/
public class AesDecryptV2 extends AesDecrypt {
/**
* AesDecryptV2
*/
public AesDecryptV2(Expression arg0, Expression arg1) {
super(arg0, arg1, getDefaultBlockEncryptionMode());
String blockEncryptionMode = String.valueOf(getDefaultBlockEncryptionMode());
if (!blockEncryptionMode.toUpperCase().equals("'AES_128_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_192_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_256_ECB'")) {
throw new AnalysisException("Incorrect parameter count in the call to native function 'aes_decrypt'");
}
}
public AesDecryptV2(Expression arg0, Expression arg1, Expression arg2) {
super(arg0, arg1, arg2);
}
public AesDecryptV2(Expression arg0, Expression arg1, Expression arg2, Expression arg3) {
super(arg0, arg1, arg2, arg3);
}
/**
* withChildren.
*/
@Override
public AesDecryptV2 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
if (children.size() == 2) {
return new AesDecryptV2(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new AesDecryptV2(children.get(0), children.get(1), children.get(2));
} else {
return new AesDecryptV2(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAesDecryptV2(this, context);
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -58,16 +57,7 @@ public class AesEncrypt extends AesCryptoFunction {
* Some javadoc for checkstyle...
*/
public AesEncrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
super("aes_encrypt", arg0, arg1, new StringLiteral("AES_128_ECB"));
// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("AES_128_ECB");
if (!AES_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with aes");
}
super("aes_encrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}
public AesEncrypt(Expression arg0, Expression arg1, Expression arg2) {
@ -89,7 +79,7 @@ public class AesEncrypt extends AesCryptoFunction {
} else if (children().size() == 3) {
return new AesEncrypt(children.get(0), children.get(1), children.get(2));
} else {
return new AesEncrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

View File

@ -1,74 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* ScalarFunction 'aes_encrypt'. This class is generated by GenerateFunction.
*/
public class AesEncryptV2 extends AesEncrypt {
/**
* AesEncryptV2
*/
public AesEncryptV2(Expression arg0, Expression arg1) {
super(arg0, arg1, getDefaultBlockEncryptionMode());
String blockEncryptionMode = String.valueOf(getDefaultBlockEncryptionMode());
if (!blockEncryptionMode.toUpperCase().equals("'AES_128_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_192_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_256_ECB'")) {
throw new AnalysisException("Incorrect parameter count in the call to native function 'aes_encrypt'");
}
}
public AesEncryptV2(Expression arg0, Expression arg1, Expression arg2) {
super(arg0, arg1, arg2);
}
public AesEncryptV2(Expression arg0, Expression arg1, Expression arg2, Expression arg3) {
super(arg0, arg1, arg2, arg3);
}
/**
* withChildren.
*/
@Override
public AesEncryptV2 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
if (children.size() == 2) {
return new AesEncryptV2(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new AesEncryptV2(children.get(0), children.get(1), children.get(2));
} else {
return new AesEncryptV2(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAesEncryptV2(this, context);
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -63,17 +62,7 @@ public class Sm4Decrypt extends Sm4CryptoFunction {
* constructor with 2 arguments.
*/
public Sm4Decrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need add an empty string as the third param
// and set encryption mode to SM4_128_ECB
// this keeps the behavior consistent with old doris ver.
super("sm4_decrypt", arg0, arg1, new StringLiteral(""), new StringLiteral("SM4_128_ECB"));
// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("SM4_128_ECB");
if (!SM4_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with sm4");
}
super("sm4_decrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}
/**
@ -98,7 +87,7 @@ public class Sm4Decrypt extends Sm4CryptoFunction {
} else if (children().size() == 3) {
return new Sm4Decrypt(children.get(0), children.get(1), children.get(2));
} else {
return new Sm4Decrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new Sm4Decrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

View File

@ -1,69 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* ScalarFunction 'sm4_decrypt'. This class is generated by GenerateFunction.
*/
public class Sm4DecryptV2 extends Sm4Decrypt {
/**
* Sm4DecryptV2
*/
public Sm4DecryptV2(Expression arg0, Expression arg1) {
super(arg0, arg1);
throw new AnalysisException("Incorrect parameter count in the call to native function 'sm4_decrypt'");
}
public Sm4DecryptV2(Expression arg0, Expression arg1, Expression arg2) {
super(arg0, arg1, arg2);
}
public Sm4DecryptV2(Expression arg0, Expression arg1, Expression arg2, Expression arg3) {
super(arg0, arg1, arg2, arg3);
}
/**
* withChildren.
*/
@Override
public Sm4DecryptV2 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
if (children.size() == 2) {
return new Sm4DecryptV2(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new Sm4DecryptV2(children.get(0), children.get(1), children.get(2));
} else {
return new Sm4DecryptV2(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSm4DecryptV2(this, context);
}
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -58,17 +57,7 @@ public class Sm4Encrypt extends Sm4CryptoFunction {
* constructor with 2 arguments.
*/
public Sm4Encrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need add an empty string as the third param
// and set encryption mode to SM4_128_ECB
// this keeps the behavior consistent with old doris ver.
super("sm4_encrypt", arg0, arg1, new StringLiteral(""), new StringLiteral("SM4_128_ECB"));
// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("SM4_128_ECB");
if (!SM4_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with sm4");
}
super("sm4_encrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}
/**
@ -93,7 +82,7 @@ public class Sm4Encrypt extends Sm4CryptoFunction {
} else if (children().size() == 3) {
return new Sm4Encrypt(children.get(0), children.get(1), children.get(2));
} else {
return new Sm4Encrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new Sm4Encrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

View File

@ -1,72 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* ScalarFunction 'sm4_encrypt'. This class is generated by GenerateFunction.
*/
public class Sm4EncryptV2 extends Sm4Encrypt {
/**
* constructor with 2 arguments.
*/
public Sm4EncryptV2(Expression arg0, Expression arg1) {
super(arg0, arg1);
throw new AnalysisException("Incorrect parameter count in the call to native function 'sm4_encrypt'");
}
/**
* constructor with 3 arguments.
*/
public Sm4EncryptV2(Expression arg0, Expression arg1, Expression arg2) {
super(arg0, arg1, arg2);
}
public Sm4EncryptV2(Expression arg0, Expression arg1, Expression arg2, Expression arg3) {
super(arg0, arg1, arg2, arg3);
}
/**
* withChildren.
*/
@Override
public Sm4EncryptV2 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
if (children.size() == 2) {
return new Sm4EncryptV2(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new Sm4EncryptV2(children.get(0), children.get(1), children.get(2));
} else {
return new Sm4EncryptV2(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSm4EncryptV2(this, context);
}
}

View File

@ -24,9 +24,7 @@ import org.apache.doris.nereids.trees.expressions.functions.combinator.StateComb
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Acos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AppendTrailingCharIfAbsent;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayApply;
@ -356,9 +354,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Sleep;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Decrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4DecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Encrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4EncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Space;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByChar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByString;
@ -466,18 +462,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(aesDecrypt, context);
}
default R visitAesDecryptV2(AesDecryptV2 aesDecryptV2, C context) {
return visitScalarFunction(aesDecryptV2, context);
}
default R visitAesEncrypt(AesEncrypt aesEncrypt, C context) {
return visitScalarFunction(aesEncrypt, context);
}
default R visitAesEncryptV2(AesEncryptV2 aesEncryptV2, C context) {
return visitScalarFunction(aesEncryptV2, context);
}
default R visitAppendTrailingCharIfAbsent(AppendTrailingCharIfAbsent function, C context) {
return visitScalarFunction(function, context);
}
@ -1774,18 +1762,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(sm4Decrypt, context);
}
default R visitSm4DecryptV2(Sm4DecryptV2 sm4DecryptV2, C context) {
return visitScalarFunction(sm4DecryptV2, context);
}
default R visitSm4Encrypt(Sm4Encrypt sm4Encrypt, C context) {
return visitScalarFunction(sm4Encrypt, context);
}
default R visitSm4EncryptV2(Sm4EncryptV2 sm4EncryptV2, C context) {
return visitScalarFunction(sm4EncryptV2, context);
}
default R visitSpace(Space space, C context) {
return visitScalarFunction(space, context);
}