diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index cb61795865239b..4cc2ed4874f2a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -17,18 +17,17 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; -import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; @@ -38,22 +37,19 @@ import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; -import org.apache.doris.nereids.types.BooleanType; -import org.apache.doris.nereids.types.DateTimeType; -import org.apache.doris.nereids.types.DateTimeV2Type; -import org.apache.doris.nereids.types.DateType; -import org.apache.doris.nereids.types.DateV2Type; -import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.*; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; @@ -62,9 +58,10 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.List; +import java.util.Optional; /** - * simplify comparison + * simplify comparison, not support large int. * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral * cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type) */ @@ -98,22 +95,24 @@ public static Expression simplify(ComparisonPredicate cp) { Expression left = cp.left(); Expression right = cp.right(); - // float like type: float, double - if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) { - return processFloatLikeTypeCoercion(cp, left, right); - } + Expression result; - // decimalv3 type - if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) { - return processDecimalV3TypeCoercion(cp, left, right); + // process type coercion + if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) { + result = processFloatLikeTypeCoercion(cp, left, right); + } else if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) { + result = processDecimalV3TypeCoercion(cp, left, right); + } else if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) { + result = processDateLikeTypeCoercion(cp, left, right); + } else { + result = cp; } - // date like type - if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) { - return processDateLikeTypeCoercion(cp, left, right); + if (result instanceof ComparisonPredicate) { + result = processTypeRangeLimitComparision((ComparisonPredicate) result); } - return cp; + return result; } private static Expression processComparisonPredicateDateTimeV2Literal( @@ -128,17 +127,13 @@ private static Expression processComparisonPredicateDateTimeV2Literal( if (right.getMicroSecond() == originValue) { return comparisonPredicate.withChildren(left, right); } else { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.getFalse(left); } } else if (comparisonPredicate instanceof NullSafeEqual) { long originValue = right.getMicroSecond(); @@ -239,18 +234,13 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa comparisonPredicate.withChildren(left, new DecimalV3Literal( literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), - new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.getFalse(left); } } else if (comparisonPredicate instanceof NullSafeEqual) { try { @@ -281,21 +271,18 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint - if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { + if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 + && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { literal = literal.stripTrailingZeros(); if (literal.scale() > 0) { if (comparisonPredicate instanceof EqualTo) { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.getFalse(left); } else if (comparisonPredicate instanceof NullSafeEqual) { return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan @@ -320,10 +307,95 @@ private static Expression processIntegerDecimalLiteralComparison( return comparisonPredicate; } + private static Expression processTypeRangeLimitComparision(ComparisonPredicate cp) { + Expression right = cp.right(); + Expression left = cp.left(); + if (left instanceof Cast) { + left = ((Cast) left).child(); + } + + BigDecimal literal = null; + if (right instanceof TinyIntLiteral) { + literal = new BigDecimal(((TinyIntLiteral) right).getValue()); + } else if (right instanceof SmallIntLiteral) { + literal = new BigDecimal(((SmallIntLiteral) right).getValue()); + } else if (right instanceof IntegerLiteral) { + literal = new BigDecimal(((IntegerLiteral) right).getValue()); + } else if (right instanceof BigIntLiteral) { + literal = new BigDecimal(((BigIntLiteral) right).getValue()); + } else if (right instanceof LargeIntLiteral) { + literal = new BigDecimal(((LargeIntLiteral) right).getValue()); + } else if (right instanceof FloatLiteral) { + literal = new BigDecimal(((FloatLiteral) right).getValue()); + } else if (right instanceof DoubleLiteral) { + literal = new BigDecimal(((DoubleLiteral) right).getValue()); + } else if (right instanceof DecimalLiteral) { + literal = ((DecimalLiteral) right).getValue(); + } else if (right instanceof DecimalV3Literal) { + literal = ((DecimalV3Literal) right).getValue(); + } + + if (literal == null) { + return cp; + } + + Optional> minMaxOpt = TypeCoercionUtils.getDataTypeMinMaxValue(left.getDataType()); + if (!minMaxOpt.isPresent()) { + return cp; + } + int cmpMin = literal.compareTo(minMaxOpt.get().first); + int cmpMax = literal.compareTo(minMaxOpt.get().second); + if (cp instanceof EqualTo) { + if (cmpMin < 0 || cmpMax > 0) { + return ExpressionUtils.getFalse(left); + } + } else if (cp instanceof NullSafeEqual) { + if (cmpMin < 0 || cmpMax > 0) { + return BooleanLiteral.of(false); + } + } else if (cp instanceof GreaterThan) { + if (cmpMin < 0) { + return ExpressionUtils.getTrue(left); + } + if (cmpMax >= 0) { + return ExpressionUtils.getFalse(left); + } + } else if (cp instanceof GreaterThanEqual) { + if (cmpMin <= 0) { + return ExpressionUtils.getTrue(left); + } + if (cmpMax == 0) { + return new EqualTo(cp.left(), cp.right()); + } + if (cmpMax > 0) { + return ExpressionUtils.getFalse(left); + } + } else if (cp instanceof LessThan) { + if (cmpMin <= 0) { + return ExpressionUtils.getFalse(left); + } + if (cmpMax > 0) { + return ExpressionUtils.getTrue(left); + } + } else if (cp instanceof LessThanEqual) { + if (cmpMin < 0) { + return ExpressionUtils.getFalse(left); + } + if (cmpMin == 0) { + return new EqualTo(cp.left(), cp.right()); + } + if (cmpMax >= 0) { + return ExpressionUtils.getTrue(left); + } + } + return cp; + } + private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { - Preconditions.checkArgument( - decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, - "decimal literal must have 0 scale and smaller than Long.MAX_VALUE"); + Preconditions.checkArgument(decimal.scale() <= 0 + && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0 + && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, + "decimal literal must have 0 scale and in range [Long.MIN_VALUE, Long.MAX_VALUE]"); long val = decimal.longValue(); if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) { return new TinyIntLiteral((byte) val); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index bf4d6e084795f1..f702ce99789841 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -160,6 +160,22 @@ public static Optional optionalAnd(List expressions) { } } + public static Expression getFalse(Expression expression) { + if (expression.nullable()) { + return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.FALSE; + } + } + + public static Expression getTrue(Expression expression) { + if (expression.nullable()) { + return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.TRUE; + } + } + /** * And two list. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index d7f9fc83baf288..7dc2db1e332af6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Add; @@ -1796,6 +1797,32 @@ private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic bina castIfNotSameType(right, dt2)); } + public static Optional> getDataTypeMinMaxValue(DataType dataType) { + if (dataType.isTinyIntType()) { + return Optional.of(Pair.of(new BigDecimal(Byte.MIN_VALUE), new BigDecimal(Byte.MIN_VALUE))); + } else if (dataType.isSmallIntType()) { + return Optional.of(Pair.of(new BigDecimal(Short.MIN_VALUE), new BigDecimal(Short.MIN_VALUE))); + } else if (dataType.isIntegerType()) { + return Optional.of(Pair.of(new BigDecimal(Integer.MIN_VALUE), new BigDecimal(Integer.MIN_VALUE))); + } else if (dataType.isBigIntType()) { + return Optional.of(Pair.of(new BigDecimal(Long.MIN_VALUE), new BigDecimal(Long.MIN_VALUE))); + } else if (dataType.isLargeIntType()) { + return Optional.of(Pair.of(new BigDecimal(LargeIntType.MIN_VALUE), new BigDecimal(LargeIntType.MIN_VALUE))); + } else if (dataType.isFloatType()) { + //minVal = BigDecimal.valueOf(-Float.MAX_VALUE); + return Optional.of(Pair.of(new BigDecimal(Float.MIN_VALUE), new BigDecimal(Float.MIN_VALUE))); + } else if (dataType.isDoubleType()) { + //minVal = BigDecimal.valueOf(-Double.MAX_VALUE); + return Optional.of(Pair.of(new BigDecimal(Double.MIN_VALUE), new BigDecimal(Double.MIN_VALUE))); + } else if (dataType.isDecimalV2Type()) { + // DecimalV2Type type = (DecimalV2Type) leftType; + } else if (dataType.isDecimalV3Type()) { + // DecimalV3Type type = (DecimalV3Type) leftType; + } + + return Optional.empty(); + } + private static boolean supportCompare(DataType dataType) { if (dataType.isArrayType()) { return true;