[feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt (#40004) (#40672)

pick #40004 to branch-2.1
This commit is contained in:
camby
2024-09-11 23:28:28 +08:00
committed by GitHub
parent bf156d1665
commit 361a59dec8
9 changed files with 357 additions and 77 deletions

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import com.google.common.collect.ImmutableSet;
@ -52,7 +53,16 @@ public abstract class AesCryptoFunction extends CryptoFunction {
"AES_256_CTR",
"AES_128_OFB",
"AES_192_OFB",
"AES_256_OFB"
"AES_256_OFB",
"AES_128_GCM",
"AES_192_GCM",
"AES_256_GCM"
);
public static final Set<String> AES_GCM_MODES = ImmutableSet.of(
"AES_128_GCM",
"AES_192_GCM",
"AES_256_GCM"
);
public AesCryptoFunction(String name, Expression... arguments) {
@ -72,4 +82,17 @@ public abstract class AesCryptoFunction extends CryptoFunction {
}
return encryptionMode;
}
@Override
public void checkLegalityAfterRewrite() {
if (arity() >= 4 && child(3) instanceof StringLikeLiteral) {
String mode = ((StringLikeLiteral) child(3)).getValue().toUpperCase();
if (!AES_MODES.contains(mode)) {
throw new AnalysisException("mode " + mode + " is not supported");
}
if (arity() == 5 && !AES_GCM_MODES.contains(mode)) {
throw new AnalysisException("only GCM mode support AAD(the 5th arg)");
}
}
}
}

View File

@ -50,7 +50,16 @@ public class AesDecrypt extends AesCryptoFunction {
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE,
StringType.INSTANCE)
);
/**
@ -68,18 +77,25 @@ public class AesDecrypt extends AesCryptoFunction {
super("aes_decrypt", arg0, arg1, arg2, arg3);
}
public AesDecrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) {
super("aes_decrypt", arg0, arg1, arg2, arg3, arg4);
}
/**
* withChildren.
*/
@Override
public AesDecrypt withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5);
if (children.size() == 2) {
return new AesDecrypt(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new AesDecrypt(children.get(0), children.get(1), children.get(2));
} else {
} else if (children().size() == 4) {
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3));
} else {
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3),
children.get(4));
}
}

View File

@ -50,7 +50,16 @@ public class AesEncrypt extends AesCryptoFunction {
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT,
VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE,
StringType.INSTANCE)
);
/**
@ -68,18 +77,25 @@ public class AesEncrypt extends AesCryptoFunction {
super("aes_encrypt", arg0, arg1, arg2, arg3);
}
public AesEncrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) {
super("aes_encrypt", arg0, arg1, arg2, arg3, arg4);
}
/**
* withChildren.
*/
@Override
public AesEncrypt withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5);
if (children.size() == 2) {
return new AesEncrypt(children.get(0), children.get(1));
} else if (children().size() == 3) {
return new AesEncrypt(children.get(0), children.get(1), children.get(2));
} else {
} else if (children().size() == 4) {
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3));
} else {
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3),
children.get(4));
}
}