Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yujun777 committed Nov 28, 2024
1 parent ec00b28 commit c661a67
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
Expand Down Expand Up @@ -316,7 +318,13 @@ private static Expression processTypeRangeLimitComparison(ComparisonPredicate cp
left = ((Cast) left).child();
}

Optional<Pair<BigDecimal, BigDecimal>> minMaxOpt = TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType());
// cmp float like have lost precision
DataType leftType = left.getDataType();
if (!(leftType.isIntegerLikeType() || leftType.isDecimalV3Type())) {
return cp;
}

Optional<Pair<BigDecimal, BigDecimal>> minMaxOpt = TypeCoercionUtils.getDataTypeMinMaxValue(leftType);
if (!minMaxOpt.isPresent()) {
return cp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1805,25 +1805,52 @@ private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic bina
*/
public static Optional<Pair<BigDecimal, BigDecimal>> getDataTypeMinMaxValue(DataType dataType) {
if (dataType.isTinyIntType()) {
return Optional.of(Pair.of(new BigDecimal(Byte.MIN_VALUE), new BigDecimal(Byte.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(Byte.MIN_VALUE), new BigDecimal(Byte.MAX_VALUE)));
} else if (dataType.isSmallIntType()) {
return Optional.of(Pair.of(new BigDecimal(Short.MIN_VALUE), new BigDecimal(Short.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(Short.MIN_VALUE), new BigDecimal(Short.MAX_VALUE)));
} else if (dataType.isIntegerType()) {
return Optional.of(Pair.of(new BigDecimal(Integer.MIN_VALUE), new BigDecimal(Integer.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(Integer.MIN_VALUE), new BigDecimal(Integer.MAX_VALUE)));
} else if (dataType.isBigIntType()) {
return Optional.of(Pair.of(new BigDecimal(Long.MIN_VALUE), new BigDecimal(Long.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(Long.MIN_VALUE), new BigDecimal(Long.MAX_VALUE)));
} else if (dataType.isLargeIntType()) {
return Optional.of(Pair.of(new BigDecimal(LargeIntType.MIN_VALUE), new BigDecimal(LargeIntType.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(LargeIntType.MIN_VALUE), new BigDecimal(LargeIntType.MAX_VALUE)));
} else if (dataType.isFloatType()) {
//minVal = BigDecimal.valueOf(-Float.MAX_VALUE);
return Optional.of(Pair.of(new BigDecimal(Float.MIN_VALUE), new BigDecimal(Float.MIN_VALUE)));
return Optional.of(Pair.of(new BigDecimal(String.valueOf(Float.MIN_VALUE)),
new BigDecimal(String.valueOf(Float.MAX_VALUE))));
} else if (dataType.isDoubleType()) {
//minVal = BigDecimal.valueOf(-Double.MAX_VALUE);
return Optional.of(Pair.of(new BigDecimal(Double.MIN_VALUE), new BigDecimal(Double.MIN_VALUE)));
} else if (dataType.isDecimalV2Type()) {
// DecimalV2Type type = (DecimalV2Type) leftType;
} else if (dataType.isDecimalV3Type()) {
// DecimalV3Type type = (DecimalV3Type) leftType;
return Optional.of(Pair.of(new BigDecimal(String.valueOf(Double.MIN_VALUE)),
new BigDecimal(String.valueOf(Double.MAX_VALUE))));
} else if (dataType.isDecimalLikeType()) {
int precision = -1;
int scale = -1;
if (dataType instanceof DecimalV2Type) {
DecimalV2Type type = (DecimalV2Type) dataType;
precision = type.getPrecision();
scale = type.getScale();
}
if (dataType instanceof DecimalV3Type) {
DecimalV3Type type = (DecimalV3Type) dataType;
precision = type.getPrecision();
scale = type.getScale();
}
if (scale >= 0) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < precision - scale; i ++) {
sb.append('9');
}
if (sb.length() == 0) {
sb.append('0');
}
if (scale > 0) {
sb.append('.');
for (int i = 0; i < scale; i++) {
sb.append('9');
}
}
return Optional.of(Pair.of(new BigDecimal("-" + sb.toString()), new BigDecimal(sb.toString())));
}
}

return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ private enum RangeLimitResult {
TRUE, // eval to true
FALSE, // eval to false
EQUALS, // eval to equals
NO_CONVERT // no convert
NO_CHANGE_CP // no change cmp type
}

@Test
Expand All @@ -315,36 +315,96 @@ void testTypeRangeLimit() {
));

checkTypeRangeLimit(TinyIntType.INSTANCE,
ImmutableList.of(new SmallIntLiteral((short) -129), new DoubleLiteral(-129.0), new DoubleLiteral(-128.1)),
ImmutableList.of(new TinyIntLiteral((byte) -128), new DoubleLiteral(-128.0)),
ImmutableList.of(new TinyIntLiteral((byte) -127), new DoubleLiteral(-127.0),
new TinyIntLiteral((byte) 126), new DoubleLiteral(126.0)),
ImmutableList.of(new TinyIntLiteral((byte) 127), new DoubleLiteral(127.0)),
ImmutableList.of(new SmallIntLiteral((short) 128), new DoubleLiteral(128.0), new DoubleLiteral(127.1)));
ImmutableList.of(
Pair.of(new SmallIntLiteral((short) -129), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-129")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-128.1")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-1000.1")), null),
Pair.of(new DoubleLiteral(-129.0), new SmallIntLiteral((short) -129)),
Pair.of(new DoubleLiteral(-128.1), new DecimalV3Literal(new BigDecimal("-128.1")))),
ImmutableList.of(
Pair.of(new TinyIntLiteral((byte) -128), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-128")), new TinyIntLiteral((byte) -128)),
Pair.of(new DoubleLiteral(-128.0), new TinyIntLiteral((byte) -128))),
ImmutableList.of(
Pair.of(new TinyIntLiteral((byte) -127), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-127")), new TinyIntLiteral((byte) -127)),
Pair.of(new DoubleLiteral(-127.0), new TinyIntLiteral((byte) -127)),
Pair.of(new TinyIntLiteral((byte) 126), null),
Pair.of(new DoubleLiteral(126.0), new TinyIntLiteral((byte) 126))),
ImmutableList.of(
Pair.of(new TinyIntLiteral((byte) 127), null),
Pair.of(new DecimalV3Literal(new BigDecimal("127")), new TinyIntLiteral((byte) 127)),
Pair.of(new DecimalV3Literal(new BigDecimal("127.00")), new TinyIntLiteral((byte) 127)),
Pair.of(new DoubleLiteral(127.0), new TinyIntLiteral((byte) 127))),
ImmutableList.of(
Pair.of(new SmallIntLiteral((short) 128), null),
Pair.of(new DecimalV3Literal(new BigDecimal("128.02")), null),
Pair.of(new DoubleLiteral(128.0), new SmallIntLiteral((short) 128)),
Pair.of(new DoubleLiteral(127.1), new DecimalV3Literal(new BigDecimal("127.1")))));

checkTypeRangeLimit(DecimalV3Type.createDecimalV3Type(5, 2),
ImmutableList.of(
Pair.of(new DecimalV3Literal(new BigDecimal("-999.999")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-1000.00")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("-1000.0123")), null)),
ImmutableList.of(
Pair.of(new DecimalV3Literal(new BigDecimal("-999.99")), null)),
ImmutableList.of(
Pair.of(new DecimalV3Literal(new BigDecimal("100.4")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("100")), null)),
ImmutableList.of(
Pair.of(new DecimalV3Literal(new BigDecimal("999.99")), null)),
ImmutableList.of(
Pair.of(new DecimalV3Literal(new BigDecimal("1000")), null),
Pair.of(new DecimalV3Literal(new BigDecimal("999.999")), null)));

// cmp float / double lost precision
// checkTypeRangeLimit(FloatType.INSTANCE,
// ImmutableList.of(
// Pair.of(new DecimalV3Literal(new BigDecimal((double) Float.MIN_VALUE - 1)), null),
// Pair.of(new DoubleLiteral((double) Float.MIN_VALUE - 1), null)),
// ImmutableList.of(
// Pair.of(new FloatLiteral(Float.MIN_VALUE), null)
// Pair.of(new DoubleLiteral((double) Float.MIN_VALUE), null)
// ),
// ImmutableList.of(
// Pair.of(new FloatLiteral(Float.MIN_VALUE + 0.001f), null),
// Pair.of(new FloatLiteral(Float.MAX_VALUE - 0.001f), null),
// Pair.of(new IntegerLiteral(100), null)),
// ImmutableList.of(
// Pair.of(new FloatLiteral(Float.MAX_VALUE), null),
// Pair.of(new DoubleLiteral((double) Float.MAX_VALUE), null)),
// ImmutableList.of(
// Pair.of(new DoubleLiteral((double) Float.MAX_VALUE + 0.1), null)));
}

void checkTypeRangeLimit(DataType dataType, List<Expression> lessThanMinExpr, List<Expression> minExpr,
List<Expression> betweenMinMaxExpr, List<Expression> maxExpr, List<Expression> greaterThanMaxExpr) {
void checkTypeRangeLimit(DataType dataType, List<Pair<Expression, Expression>> lessThanMinExpr,
List<Pair<Expression, Expression>> minExpr, List<Pair<Expression, Expression>> betweenMinMaxExpr,
List<Pair<Expression, Expression>> maxExpr, List<Pair<Expression, Expression>> greaterThanMaxExpr) {
// due to ComparisonPredicate constructor require not null left and right child,
// use a dummyExpr as ComparisonPredicate's child
Expression dummyExpr = new SmallIntLiteral((short) 100);
// cp -> list of cp with lessThanMinExpr, minExpr, betweenMinMaxExpr, maxExpr, greaterThanMaxExpr
List<Pair<ComparisonPredicate, List<RangeLimitResult>>> cmpResults = ImmutableList.of(
Pair.of(new EqualTo(null, null), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.NO_CONVERT, RangeLimitResult.NO_CONVERT,
RangeLimitResult.NO_CONVERT, RangeLimitResult.FALSE)),
Pair.of(new NullSafeEqual(null, null), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.NO_CONVERT, RangeLimitResult.NO_CONVERT,
RangeLimitResult.NO_CONVERT, RangeLimitResult.FALSE)),
Pair.of(new GreaterThan(null, null), ImmutableList.of(
RangeLimitResult.TRUE, RangeLimitResult.NO_CONVERT, RangeLimitResult.NO_CONVERT,
RangeLimitResult.FALSE, RangeLimitResult.FALSE)),
Pair.of(new GreaterThanEqual(null, null), ImmutableList.of(
RangeLimitResult.TRUE, RangeLimitResult.TRUE, RangeLimitResult.NO_CONVERT,
RangeLimitResult.EQUALS, RangeLimitResult.FALSE)),
Pair.of(new LessThan(null, null), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.FALSE, RangeLimitResult.NO_CONVERT,
RangeLimitResult.NO_CONVERT, RangeLimitResult.TRUE)),
Pair.of(new LessThanEqual(null, null), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.EQUALS, RangeLimitResult.NO_CONVERT,
RangeLimitResult.TRUE, RangeLimitResult.TRUE))
Pair.of(new EqualTo(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.FALSE)),
Pair.of(new NullSafeEqual(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.FALSE)),
Pair.of(new GreaterThan(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.TRUE, RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.FALSE, RangeLimitResult.FALSE)),
Pair.of(new GreaterThanEqual(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.TRUE, RangeLimitResult.TRUE, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.EQUALS, RangeLimitResult.FALSE)),
Pair.of(new LessThan(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.FALSE, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.NO_CHANGE_CP, RangeLimitResult.TRUE)),
Pair.of(new LessThanEqual(dummyExpr, dummyExpr), ImmutableList.of(
RangeLimitResult.FALSE, RangeLimitResult.EQUALS, RangeLimitResult.NO_CHANGE_CP,
RangeLimitResult.TRUE, RangeLimitResult.TRUE))
);

for (Pair<ComparisonPredicate, List<RangeLimitResult>> cmpResult : cmpResults) {
Expand All @@ -358,36 +418,46 @@ void checkTypeRangeLimit(DataType dataType, List<Expression> lessThanMinExpr, Li
}
}

void checkTypeRangeLimitWithComparison(DataType dataType, ComparisonPredicate cp, List<Expression> exprs,
RangeLimitResult result) {
void checkTypeRangeLimitWithComparison(DataType dataType, ComparisonPredicate cp,
List<Pair<Expression, Expression>> exprs, RangeLimitResult result) {
Expression slot = new SlotReference("slot", dataType, true);
for (Expression right : exprs) {
for (int i = 0; i < 2; i++) {
Expression left = slot;
if (i == 1) {
left = new Cast(slot, right.getDataType());
}
Expression originExpr = cp.withChildren(left, right);
Expression rewrittenExpr = executor.rewrite(originExpr, context);
Expression expectExpr = null;
switch (result) {
case TRUE:
expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.TRUE
: ExpressionUtils.getTrue(left);
break;
case FALSE:
expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.FALSE
: ExpressionUtils.getFalse(left);
break;
case EQUALS:
expectExpr = new EqualTo(left, right);
break;
case NO_CONVERT:
expectExpr = originExpr;
break;
default:
Assertions.assertTrue(false);
}
for (Pair<Expression, Expression> pair : exprs) {
Expression right = pair.first;
Expression rewriteRight = pair.second;
if (rewriteRight == null) {
rewriteRight = right;
}
Expression left = slot;
if (!left.getDataType().equals(right.getDataType())) {
left = new Cast(slot, right.getDataType());
}
Expression originExpr = cp.withChildren(left, right);
Expression rewrittenExpr = executor.rewrite(originExpr, context);
Expression expectExpr = null;
// System.out.println("origin expr: " + originExpr);
// System.out.println("rewrite expr: " + rewrittenExpr);
switch (result) {
case TRUE:
expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.TRUE
: ExpressionUtils.getTrue(slot);
break;
case FALSE:
expectExpr = cp instanceof NullSafeEqual ? BooleanLiteral.FALSE
: ExpressionUtils.getFalse(slot);
break;
case EQUALS:
expectExpr = new EqualTo(slot, rewriteRight);
break;
case NO_CHANGE_CP:
Assertions.assertInstanceOf(cp.getClass(), rewrittenExpr);
if (!(slot.getDataType().isDecimalV3Type() && right.getDataType().isDecimalV3Type())) {
Assertions.assertEquals(slot, rewrittenExpr.child(0));
}
break;
default:
Assertions.assertTrue(false);
}
if (expectExpr != null) {
Assertions.assertEquals(expectExpr, rewrittenExpr);
}
}
Expand Down

0 comments on commit c661a67

Please sign in to comment.