branch-2.1: [fix](nereids) fix compare ipv4 / ipv6 always equals (#47514)

### What problem does this PR solve?

fix compare ipv4 and ipv6 always equals.

for example, `select cast('127.0.0.1' as ipv4) = cast('192.168.10.10' as
ipv4)'` will return 1, but it should return 0;

Issue Number: close #xxx
This commit is contained in:
yujun
2025-02-09 04:43:07 +08:00
committed by GitHub
parent 3f250e55ce
commit 8ff4ae879e
12 changed files with 309 additions and 13 deletions

View File

@ -157,6 +157,12 @@ under the License.
<artifactId>guava-testlib</artifactId>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/com.googlecode.java-ipv6/java-ipv6 -->
<dependency>
<groupId>com.googlecode.java-ipv6</groupId>
<artifactId>java-ipv6</artifactId>
<version>0.17</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>

View File

@ -666,7 +666,7 @@ public class DateLiteral extends LiteralExpr {
return diff < 0 ? -1 : (diff == 0 ? 0 : 1);
}
// date time will not overflow when doing addition and subtraction
return getStringValue().compareTo(expr.getStringValue());
return Integer.signum(getStringValue().compareTo(expr.getStringValue()));
}
@Override

View File

@ -245,6 +245,9 @@ public class DecimalLiteral extends NumericLiteralExpr {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
if (expr instanceof DecimalLiteral) {
return this.value.compareTo(((DecimalLiteral) expr).value);
} else {

View File

@ -121,6 +121,9 @@ public class FloatLiteral extends NumericLiteralExpr {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
return Double.compare(value, expr.getDoubleValue());
}

View File

@ -127,8 +127,18 @@ public class IPv4Literal extends LiteralExpr {
}
@Override
public int compareLiteral(LiteralExpr expr) {
return 0;
public int compareLiteral(LiteralExpr other) {
if (other instanceof IPv4Literal) {
return Long.compare(value, ((IPv4Literal) other).value);
}
if (other instanceof NullLiteral) {
return 1;
}
if (other instanceof MaxLiteral) {
return -1;
}
throw new RuntimeException("Cannot compare two values with different data types: "
+ this + " (" + getClass() + ") vs " + other + " (" + other.getClass() + ")");
}
@Override

View File

@ -24,6 +24,10 @@ import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprNodeType;
import org.apache.doris.thrift.TIPv6Literal;
import com.google.common.base.Suppliers;
import com.googlecode.ipv6.IPv6Address;
import java.util.function.Supplier;
import java.util.regex.Pattern;
public class IPv6Literal extends LiteralExpr {
@ -37,6 +41,8 @@ public class IPv6Literal extends LiteralExpr {
private String value;
private Supplier<IPv6Address> ipv6Value = Suppliers.memoize(() -> IPv6Address.fromString(value));
/**
* C'tor forcing type, e.g., due to implicit cast
*/
@ -92,8 +98,18 @@ public class IPv6Literal extends LiteralExpr {
}
@Override
public int compareLiteral(LiteralExpr expr) {
return 0;
public int compareLiteral(LiteralExpr other) {
if (other instanceof IPv6Literal) {
return ipv6Value.get().compareTo(((IPv6Literal) other).ipv6Value.get());
}
if (other instanceof NullLiteral) {
return 1;
}
if (other instanceof MaxLiteral) {
return -1;
}
throw new RuntimeException("Cannot compare two values with different data types: "
+ this + " (" + getClass() + ") vs " + other + " (" + other.getClass() + ")");
}
@Override

View File

@ -257,17 +257,13 @@ public class IntLiteral extends NumericLiteralExpr {
if (expr instanceof NullLiteral) {
return 1;
}
if (expr instanceof StringLiteral) {
return ((StringLiteral) expr).compareLiteral(this);
}
if (expr == MaxLiteral.MAX_VALUE) {
return -1;
}
if (value == expr.getLongValue()) {
return 0;
} else {
return value > expr.getLongValue() ? 1 : -1;
if (expr instanceof StringLiteral) {
return - ((StringLiteral) expr).compareLiteral(this);
}
return Long.compare(value, expr.getLongValue());
}
@Override

View File

@ -64,6 +64,8 @@ public abstract class ComparisonPredicate extends BinaryOperator {
for (Expression c : children) {
if (c.getDataType().isComplexType() && !c.getDataType().isArrayType()) {
throw new AnalysisException("comparison predicate could not contains complex type: " + this.toSql());
} else if (c.getDataType().isJsonType()) {
throw new AnalysisException("comparison predicate could not contains json type: " + this.toSql());
}
}
}

View File

@ -96,7 +96,7 @@ public class IndexesProcNodeTest {
Assert.assertEquals(procResult.getRows().get(3).get(5), "col_4");
Assert.assertEquals(procResult.getRows().get(3).get(11), "NGRAM_BF");
Assert.assertEquals(procResult.getRows().get(3).get(12), "ngram_bf index on col_4");
Assert.assertEquals(procResult.getRows().get(3).get(13), "(\"gram_size\" = \"3\", \"bf_size\" = \"256\")");
Assert.assertEquals(procResult.getRows().get(3).get(13), "(\"bf_size\" = \"256\", \"gram_size\" = \"3\")");
}
}

View File

@ -0,0 +1,97 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.common.ExceptionChecker;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class CompareLiteralTest extends TestWithFeService {
@Test
public void testScalar() {
// ip type
checkCompareSameType(0, new IPv4Literal("170.0.0.100"), new IPv4Literal("170.0.0.100"));
checkCompareSameType(1, new IPv4Literal("170.0.0.100"), new IPv4Literal("160.0.0.200"));
checkCompareDiffType(new IPv4Literal("172.0.0.100"), new IPv6Literal("1080:0:0:0:8:800:200C:417A"));
checkCompareSameType(0, new IPv6Literal("1080:0:0:0:8:800:200C:417A"), new IPv6Literal("1080:0:0:0:8:800:200C:417A"));
checkCompareSameType(1, new IPv6Literal("1080:0:0:0:8:800:200C:417A"), new IPv6Literal("1000:0:0:0:8:800:200C:41AA"));
IPv4Literal ipv4 = new IPv4Literal("170.0.0.100");
Assertions.assertEquals(ipv4, new IPv4Literal(ipv4.toLegacyLiteral().getStringValue()));
IPv6Literal ipv6 = new IPv6Literal("1080:0:0:0:8:800:200C:417A");
Assertions.assertEquals(ipv6, new IPv6Literal(ipv6.toLegacyLiteral().getStringValue()));
}
@Test
public void testComplex() throws Exception {
// array type
checkCompareSameType(0,
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(100), new IntegerLiteral(200))),
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(100), new IntegerLiteral(200))));
checkCompareSameType(1,
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(200))),
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(100), new IntegerLiteral(200))));
checkCompareSameType(1,
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(100), new IntegerLiteral(200), new IntegerLiteral(1))),
new ArrayLiteral(ImmutableList.of(new IntegerLiteral(100), new IntegerLiteral(200))));
checkComparableNoException("select array(1,2) = array(1, 2)");
checkComparableNoException("select array(1,2) > array(1, 2)");
// json type
// checkNotComparable("select cast ('[1, 2]' as json) = cast('[1, 2]' as json)",
// "comparison predicate could not contains json type");
// checkNotComparable("select cast('[1, 2]' as json) > cast('[1, 2]' as json)",
// "comparison predicate could not contains json type");
// map type
checkNotComparable("select map(1, 2) = map(1, 2)",
"can not cast from origin type map<tinyint,tinyint> to target type=double");
checkNotComparable("select map(1, 2) > map(1, 2)",
"can not cast from origin type map<tinyint,tinyint> to target type=double");
checkNotComparable("select cast('(1, 2)' as map<int, int>) = cast('(1, 2)' as map<int, int>)",
"can not cast from origin type map<int,int> to target type=double");
// struct type
checkNotComparable("select struct(1, 2) = struct(1, 2)",
"can not cast from origin type struct<col:tinyint,col:tinyint> to target type=double");
checkNotComparable("select struct(1, 2) > struct(1, 2)",
"can not cast from origin type struct<col:tinyint,col:tinyint> to target type=double");
}
private void checkCompareSameType(int expect, Literal left, Literal right) {
Assertions.assertEquals(expect, left.compareTo(right));
Assertions.assertEquals(- expect, right.compareTo(left));
}
private void checkCompareDiffType(Literal left, Literal right) {
Assertions.assertThrowsExactly(RuntimeException.class, () -> left.compareTo(right));
Assertions.assertThrowsExactly(RuntimeException.class, () -> right.compareTo(left));
}
private void checkComparableNoException(String sql) throws Exception {
ExceptionChecker.expectThrowsNoException(() -> executeSql(sql));
}
private void checkNotComparable(String sql, String expectErrMsg) throws Exception {
ExceptionChecker.expectThrowsWithMsg(IllegalStateException.class, expectErrMsg,
() -> executeSql(sql));
}
}

View File

@ -611,6 +611,16 @@ public abstract class TestWithFeService {
}
}
public void executeSql(String queryStr) throws Exception {
connectContext.getState().reset();
StmtExecutor stmtExecutor = new StmtExecutor(connectContext, queryStr);
stmtExecutor.execute();
if (connectContext.getState().getStateType() == QueryState.MysqlStateType.ERR
|| connectContext.getState().getErrorCode() != null) {
throw new IllegalStateException(connectContext.getState().getErrorMessage());
}
}
public void createDatabase(String db) throws Exception {
String createDbStmtStr = "CREATE DATABASE " + db;
CreateDbStmt createDbStmt = (CreateDbStmt) parseAndAnalyzeStmt(createDbStmtStr);

View File

@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite('test_compare_literal') {
for (def val in [true, false]) {
sql "set debug_skip_fold_constant=${val}"
// ipv4
test {
sql "select cast('170.0.0.100' as ipv4) = cast('170.0.0.100' as ipv4)"
result([[true]])
}
test {
sql "select cast('170.0.0.100' as ipv4) >= cast('170.0.0.100' as ipv4)"
result([[true]])
}
test {
sql "select cast('170.0.0.100' as ipv4) > cast('170.0.0.100' as ipv4)"
result([[false]])
}
test {
sql "select cast('170.0.0.100' as ipv4) = cast('160.0.0.200' as ipv4)"
result([[false]])
}
test {
sql "select cast('170.0.0.100' as ipv4) >= cast('160.0.0.200' as ipv4)"
result([[true]])
}
test {
sql "select cast('170.0.0.100' as ipv4) > cast('160.0.0.200' as ipv4)"
result([[true]])
}
test {
sql "select cast('170.0.0.100' as ipv4) < cast('160.0.0.200' as ipv4)"
result([[false]])
}
// ipv6
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) = cast('1080:0:0:0:8:800:200C:417A' as ipv6)"
result([[true]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) >= cast('1080:0:0:0:8:800:200C:417A' as ipv6)"
result([[true]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) > cast('1080:0:0:0:8:800:200C:417A' as ipv6)"
result([[false]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) = cast('1000:0:0:0:8:800:200C:41AA' as ipv6)"
result([[false]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) >= cast('1000:0:0:0:8:800:200C:41AA' as ipv6)"
result([[true]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) > cast('1000:0:0:0:8:800:200C:41AA' as ipv6)"
result([[true]])
}
test {
sql "select cast('1080:0:0:0:8:800:200C:417A' as ipv6) < cast('1000:0:0:0:8:800:200C:41AA' as ipv6)"
result([[false]])
}
// array
test {
sql 'select array(5, 6) = array(5, 6)'
result([[true]])
}
test {
sql 'select array(5, 6) >= array(5, 6)'
result([[true]])
}
test {
sql 'select array(5, 6) > array(5, 6)'
result([[false]])
}
test {
sql 'select array(5, 6) = array(5, 7)'
result([[false]])
}
test {
sql 'select array(5, 6) >= array(5, 7)'
result([[false]])
}
test {
sql 'select array(5, 6) > array(5, 7)'
result([[false]])
}
test {
sql 'select array(5, 6) < array(5, 7)'
result([[true]])
}
test {
sql 'select array(5, 6) < array(5, 6, 1)'
result([[true]])
}
test {
sql 'select array(5, 6) < array(6)'
result([[true]])
}
}
// test not comparable
sql 'set debug_skip_fold_constant=false'
// json
// test {
// sql "select cast('[1, 2]' as json) = cast('[1, 2]' as json)"
// exception 'comparison predicate could not contains json type'
// }
// test {
// sql "select cast('[1, 2]' as json) > cast('[1, 2]' as json)"
// exception 'comparison predicate could not contains json type'
// }
// map
test {
sql 'select map(1, 2) = map(1, 2)'
exception 'can not cast from origin type map<tinyint,tinyint> to target type=double'
}
test {
sql 'select map(1, 2) > map(1, 2)'
exception 'can not cast from origin type map<tinyint,tinyint> to target type=double'
}
// struct
test {
sql 'select struct(1, 2) = struct(1, 2)'
exception 'can not cast from origin type struct<col:tinyint,col:tinyint> to target type=double'
}
test {
sql 'select struct(1, 2) > struct(1, 2)'
exception 'can not cast from origin type struct<col:tinyint,col:tinyint> to target type=double'
}
}