diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java index b18c525c8849..257a1e1747ec 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java @@ -18,18 +18,24 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.stats; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Collections; import java.util.GregorianCalendar; import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.Set; +import com.google.common.collect.BoundType; +import com.google.common.collect.Range; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelOptUtil.InputReferencedVisitor; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -184,91 +190,309 @@ public Double visitCall(RexCall call) { return selectivity; } - private double computeRangePredicateSelectivity(RexCall call, SqlKind op) { - final boolean isLiteralLeft = call.getOperands().get(0).getKind().equals(SqlKind.LITERAL); - final boolean isLiteralRight = call.getOperands().get(1).getKind().equals(SqlKind.LITERAL); - final boolean isInputRefLeft = call.getOperands().get(0).getKind().equals(SqlKind.INPUT_REF); - final boolean isInputRefRight = call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF); + /** + * Return whether the expression is a removable cast based on stats and type bounds. + * + *

+ * In Hive, if a value cannot be represented by the cast, the result of the cast is NULL, + * and therefore cannot fulfill the predicate. So the possible range of the values + * is limited by the range of possible values of the type. + *

+ * + * @param exp the expression to check + * @param tableScan the table that provides the statistics + * @return true if the expression is a removable cast, false otherwise + */ + private boolean isRemovableCast(RexNode exp, HiveTableScan tableScan) { + if(SqlKind.CAST != exp.getKind()) { + return false; + } + RexCall cast = (RexCall) exp; + RexNode op0 = cast.getOperands().getFirst(); + if (!(op0 instanceof RexInputRef)) { + return false; + } + int index = ((RexInputRef) op0).getIndex(); + final List colStats = tableScan.getColStat(Collections.singletonList(index)); + if (colStats.isEmpty()) { + return false; + } - if (childRel instanceof HiveTableScan && isLiteralLeft != isLiteralRight && isInputRefLeft != isInputRefRight) { - final HiveTableScan t = (HiveTableScan) childRel; - final int inputRefIndex = ((RexInputRef) call.getOperands().get(isInputRefLeft ? 0 : 1)).getIndex(); - final List colStats = t.getColStat(Collections.singletonList(inputRefIndex)); + // Check that the possible values of the input column are all within the type range of the cast + // otherwise the CAST introduces some modulo-like behavior + ColStatistics colStat = colStats.getFirst(); + ColStatistics.Range colRange = colStat.getRange(); + if (colRange == null || colRange.minValue == null || colRange.maxValue == null) { + return false; + } - if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) { - final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); - final Object boundValueObject = ((RexLiteral) call.getOperands().get(isLiteralLeft ? 0 : 1)).getValue(); - final SqlTypeName typeName = call.getOperands().get(isInputRefLeft ? 
0 : 1).getType().getSqlTypeName(); - float value = extractLiteral(typeName, boundValueObject); - boolean closedBound = op.equals(SqlKind.LESS_THAN_OR_EQUAL) || op.equals(SqlKind.GREATER_THAN_OR_EQUAL); - - double selectivity; - if (op.equals(SqlKind.LESS_THAN_OR_EQUAL) || op.equals(SqlKind.LESS_THAN)) { - selectivity = closedBound ? lessThanOrEqualSelectivity(kll, value) : lessThanSelectivity(kll, value); - } else { - selectivity = closedBound ? greaterThanOrEqualSelectivity(kll, value) : greaterThanSelectivity(kll, value); - } + SqlTypeName type = cast.getType().getSqlTypeName(); + + double min; + double max; + switch (type) { + case TINYINT, SMALLINT, INTEGER, BIGINT: + min = ((Number) type.getLimit(false, SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue(); + max = ((Number) type.getLimit(true, SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue(); + break; + case TIMESTAMP, DATE: + min = Long.MIN_VALUE; + max = Long.MAX_VALUE; + break; + case FLOAT: + min = -Float.MAX_VALUE; + max = Float.MAX_VALUE; + break; + case DOUBLE, DECIMAL: + min = -Double.MAX_VALUE; + max = Double.MAX_VALUE; + break; + default: + // unknown type, do not remove the cast + return false; + } + // are all values of the input column accepted by the cast? + return min < colRange.minValue.doubleValue() && colRange.maxValue.doubleValue() < max; + } + + /** + * Get the range of values that are rounded to valid values of a DECIMAL type. 
+ * + * @param type the DECIMAL type + * @param lowerBound the lower bound type of the result + * @param upperBound the upper bound type of the result + * @return the range of the type + */ + private static Range getRangeOfDecimalType(RelDataType type, BoundType lowerBound, BoundType upperBound) { + // values outside the representable range are cast to NULL, so adapt the boundaries + int digits = type.getPrecision() - type.getScale(); + // the cast does some rounding, i.e., CAST(99.9499 AS DECIMAL(3,1)) = 99.9 + // but CAST(99.95 AS DECIMAL(3,1)) = NULL + float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1))); + // the range of values supported by the type is interval [-typeRangeExtent, typeRangeExtent] (both inclusive) + // e.g., the typeRangeExt is 99.94999 for DECIMAL(3,1) + float typeRangeExtent = Math.nextDown((float) (Math.pow(10, digits) - adjust)); + + // the resulting value of +- adjust would be rounded up, so in some cases we need to use Math.nextDown + boolean lowerInclusive = BoundType.CLOSED.equals(lowerBound); + boolean upperInclusive = BoundType.CLOSED.equals(upperBound); + float lowerUniverse = lowerInclusive ? -typeRangeExtent : Math.nextDown(-typeRangeExtent); + float upperUniverse = upperInclusive ? typeRangeExtent : Math.nextUp(typeRangeExtent); + return makeRange(lowerUniverse, lowerBound, upperUniverse, upperBound); + } + + /** + * Adjust the type boundaries if necessary. + * + *

+ * Special care is taken to support the cast to DECIMAL(precision, scale): + * The cast to DECIMAL rounds the value the same way as {@link RoundingMode#HALF_UP}. + * The boundaries are adjusted accordingly. + *

+ * + * @param predicateRange boundaries of the range predicate + * @param type the DECIMAL type + * @param typeRange the boundaries of the type range + * @return the adjusted boundary + */ + private static Range adjustRangeToDecimalType(Range predicateRange, RelDataType type, + Range typeRange) { + float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1))); + // the resulting value of +- adjust would be rounded up, so in some cases we need to use Math.nextDown + boolean lowerInclusive = BoundType.CLOSED.equals(predicateRange.lowerBoundType()); + boolean upperInclusive = BoundType.CLOSED.equals(predicateRange.upperBoundType()); + float adjusted1 = lowerInclusive ? predicateRange.lowerEndpoint() - adjust + : Math.nextDown(predicateRange.lowerEndpoint() + adjust); + float adjusted2 = upperInclusive ? Math.nextDown(predicateRange.upperEndpoint() + adjust) + : predicateRange.upperEndpoint() - adjust; + float lower = Math.max(adjusted1, typeRange.lowerEndpoint()); + float upper = Math.min(adjusted2, typeRange.upperEndpoint()); + // the boundaries might result in an invalid range (e.g., left > right) + // in that case the predicate does not select anything, and we return an empty range + return makeRange(lower, predicateRange.lowerBoundType(), upper, predicateRange.upperBoundType()); + } - // selectivity does not account for null values, we multiply for the number of non-null values (getN) - // and we divide by the total (non-null + null values) to get the overall selectivity. - // - // Example: consider a filter "col < 3", and the following table rows: - // _____ - // | col | - // |_____| - // |1 | - // |null | - // |null | - // |3 | - // |4 | - // ------- - // kll.getN() would be 3, selectivity 1/3, t.getTable().getRowCount() 5 - // so the final result would be 3 * 1/3 / 5 = 1/5, as expected. - return kll.getN() * selectivity / t.getTable().getRowCount(); + /** + * If the arguments lead to a valid range, it is returned, otherwise an empty range is returned. 
+ */ + private static Range makeRange(float lower, BoundType lowerType, float upper, BoundType upperType) { + if (lower > upper) { + return Range.closedOpen(0f, 0f); + } + if (lower == upper && lowerType == BoundType.OPEN && upperType == BoundType.OPEN) { + return Range.closedOpen(0f, 0f); + } + + return Range.range(lower, lowerType, upper, upperType); + } + + private double computeRangePredicateSelectivity(RexCall call, SqlKind op) { + double defaultSelectivity = ((double) 1 / (double) 3); + if (!(childRel instanceof HiveTableScan)) { + return defaultSelectivity; + } + + // search for the literal + List operands = call.getOperands(); + final Optional leftLiteral = extractLiteral(operands.get(0)); + final Optional rightLiteral = extractLiteral(operands.get(1)); + // ensure that there's exactly one literal + if ((leftLiteral.isPresent()) == (rightLiteral.isPresent())) { + return defaultSelectivity; + } + int literalOpIdx = leftLiteral.isPresent() ? 0 : 1; + + // analyze the predicate + float value = leftLiteral.orElseGet(rightLiteral::get); + int boundaryIdx; + boolean openBound = op == SqlKind.LESS_THAN || op == SqlKind.GREATER_THAN; + switch (op) { + case LESS_THAN, LESS_THAN_OR_EQUAL: + boundaryIdx = literalOpIdx; + break; + case GREATER_THAN, GREATER_THAN_OR_EQUAL: + boundaryIdx = 1 - literalOpIdx; + break; + default: + return defaultSelectivity; + } + float[] boundaryValues = new float[] { Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY }; + BoundType[] inclusive = new BoundType[] { BoundType.CLOSED, BoundType.CLOSED }; + boundaryValues[boundaryIdx] = value; + inclusive[boundaryIdx] = openBound ? 
BoundType.OPEN : BoundType.CLOSED; + Range boundaries = Range.range(boundaryValues[0], inclusive[0], boundaryValues[1], inclusive[1]); + + // extract the column index from the other operator + final HiveTableScan scan = (HiveTableScan) childRel; + int inputRefOpIndex = 1 - literalOpIdx; + RexNode node = operands.get(inputRefOpIndex); + if (isRemovableCast(node, scan)) { + if (node.getType().getSqlTypeName() == SqlTypeName.DECIMAL) { + Range rangeOfDecimalType = + getRangeOfDecimalType(node.getType(), boundaries.lowerBoundType(), boundaries.upperBoundType()); + boundaries = adjustRangeToDecimalType(boundaries, node.getType(), rangeOfDecimalType); } + node = RexUtil.removeCast(node); + } + + int inputRefIndex = -1; + if (node.getKind().equals(SqlKind.INPUT_REF)) { + inputRefIndex = ((RexInputRef) node).getIndex(); + } + + if (inputRefIndex < 0) { + return defaultSelectivity; + } + + final List colStats = scan.getColStat(Collections.singletonList(inputRefIndex)); + if (colStats.isEmpty() || !isHistogramAvailable(colStats.get(0))) { + return defaultSelectivity; } - return ((double) 1 / (double) 3); + + final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); + double rawSelectivity = rangedSelectivity(kll, boundaries); + return scaleSelectivityToNullableValues(kll, rawSelectivity, scan); + } + + /** + * Adjust the selectivity estimate to take NULL values into account. + *

+ * The rawSelectivity does not account for null values. We multiply by the number of non-null values (getN) + * and we divide by the total number (non-null + null values) to get the overall selectivity. + *

+ * Example: consider a filter "col < 3", and the following table rows: + *

+   *  _____
+   * | col |
+   * |_____|
+   * |1    |
+   * |null |
+   * |null |
+   * |3    |
+   * |4    |
+   * -------
+   * 
+ * kll.getN() would be 3, rawSelectivity 1/3, scan.getTable().getRowCount() 5 + * so the final result would be 3 * 1/3 / 5 = 1/5, as expected. + */ + private static double scaleSelectivityToNullableValues(KllFloatsSketch kll, double rawSelectivity, + HiveTableScan scan) { + if (scan.getTable() == null) { + return rawSelectivity; + } + return kll.getN() * rawSelectivity / scan.getTable().getRowCount(); } private Double computeBetweenPredicateSelectivity(RexCall call) { - final boolean hasLiteralBool = call.getOperands().get(0).getKind().equals(SqlKind.LITERAL); - final boolean hasInputRef = call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF); - final boolean hasLiteralLeft = call.getOperands().get(2).getKind().equals(SqlKind.LITERAL); - final boolean hasLiteralRight = call.getOperands().get(3).getKind().equals(SqlKind.LITERAL); + if (!(childRel instanceof HiveTableScan)) { + return computeFunctionSelectivity(call); + } - if (childRel instanceof HiveTableScan && hasLiteralBool && hasInputRef && hasLiteralLeft && hasLiteralRight) { - final HiveTableScan t = (HiveTableScan) childRel; - final int inputRefIndex = ((RexInputRef) call.getOperands().get(1)).getIndex(); - final List colStats = t.getColStat(Collections.singletonList(inputRefIndex)); + List operands = call.getOperands(); + final boolean hasLiteralBool = operands.get(0).getKind().equals(SqlKind.LITERAL); + Optional leftLiteral = extractLiteral(operands.get(2)); + Optional rightLiteral = extractLiteral(operands.get(3)); + + if (hasLiteralBool && leftLiteral.isPresent() && rightLiteral.isPresent()) { + final HiveTableScan scan = (HiveTableScan) childRel; + float leftValue = leftLiteral.get(); + float rightValue = rightLiteral.get(); + + final Object inverseBoolValueObject = ((RexLiteral) operands.getFirst()).getValue(); + boolean inverseBool = Boolean.parseBoolean(inverseBoolValueObject.toString()); + // when they are equal it's an equality predicate, we cannot handle it as "BETWEEN" + if 
(Objects.equals(leftValue, rightValue)) { + return inverseBool ? computeNotEqualitySelectivity(call) : computeFunctionSelectivity(call); + } + Range rangeBoundaries = makeRange(leftValue, BoundType.CLOSED, rightValue, BoundType.CLOSED); + Range typeBoundaries = inverseBool ? Range.closed(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY) : null; + + RexNode expr = operands.get(1); // expr to be checked by the BETWEEN + if (isRemovableCast(expr, scan)) { + typeBoundaries = + getRangeOfDecimalType(expr.getType(), rangeBoundaries.lowerBoundType(), rangeBoundaries.upperBoundType()); + rangeBoundaries = adjustRangeToDecimalType(rangeBoundaries, expr.getType(), typeBoundaries); + expr = RexUtil.removeCast(expr); + } + + int inputRefIndex = -1; + if (expr.getKind().equals(SqlKind.INPUT_REF)) { + inputRefIndex = ((RexInputRef) expr).getIndex(); + } + + if (inputRefIndex < 0) { + return computeFunctionSelectivity(call); + } + + final List colStats = scan.getColStat(Collections.singletonList(inputRefIndex)); if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) { final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); - final SqlTypeName typeName = call.getOperands().get(1).getType().getSqlTypeName(); - final Object inverseBoolValueObject = ((RexLiteral) call.getOperands().get(0)).getValue(); - boolean inverseBool = Boolean.parseBoolean(inverseBoolValueObject.toString()); - final Object leftBoundValueObject = ((RexLiteral) call.getOperands().get(2)).getValue(); - float leftValue = extractLiteral(typeName, leftBoundValueObject); - final Object rightBoundValueObject = ((RexLiteral) call.getOperands().get(3)).getValue(); - float rightValue = extractLiteral(typeName, rightBoundValueObject); - // when inverseBool == true, this is a NOT_BETWEEN and selectivity must be inverted + double rawSelectivity = rangedSelectivity(kll, rangeBoundaries); if (inverseBool) { - if (rightValue == leftValue) { - return 
computeNotEqualitySelectivity(call); - } else if (rightValue < leftValue) { - return 1.0; - } - return 1.0 - (kll.getN() * betweenSelectivity(kll, leftValue, rightValue) / t.getTable().getRowCount()); - } - // when they are equal it's an equality predicate, we cannot handle it as "between" - if (Double.compare(leftValue, rightValue) != 0) { - return kll.getN() * betweenSelectivity(kll, leftValue, rightValue) / t.getTable().getRowCount(); + // when inverseBool == true, this is a NOT_BETWEEN and selectivity must be inverted + // if there's a cast, the inversion is with respect to its codomain (range of the values of the cast) + double typeRangeSelectivity = rangedSelectivity(kll, typeBoundaries); + rawSelectivity = typeRangeSelectivity - rawSelectivity; } + return scaleSelectivityToNullableValues(kll, rawSelectivity, scan); } } return computeFunctionSelectivity(call); } - private float extractLiteral(SqlTypeName typeName, Object boundValueObject) { + private Optional extractLiteral(RexNode node) { + if (node.getKind() != SqlKind.LITERAL) { + return Optional.empty(); + } + RexLiteral literal = (RexLiteral) node; + if (literal.getValue() == null) { + return Optional.empty(); + } + return extractLiteral(literal.getTypeName(), literal.getValue()); + } + + private Optional extractLiteral(SqlTypeName typeName, Object boundValueObject) { final String boundValueString = boundValueObject.toString(); float value; @@ -299,10 +523,9 @@ private float extractLiteral(SqlTypeName typeName, Object boundValueObject) { value = ((GregorianCalendar) boundValueObject).toInstant().getEpochSecond(); break; default: - throw new IllegalStateException( - "Unsupported type for comparator selectivity evaluation using histogram: " + typeName); + return Optional.empty(); } - return value; + return Optional.of(value); } /** @@ -470,7 +693,7 @@ private boolean isPartitionPredicate(RexNode expr, RelNode r) { } else if (r instanceof Filter) { return isPartitionPredicate(expr, ((Filter) r).getInput()); 
} else if (r instanceof HiveTableScan) { - RelOptHiveTable table = (RelOptHiveTable) ((HiveTableScan) r).getTable(); + RelOptHiveTable table = (RelOptHiveTable) r.getTable(); ImmutableBitSet cols = RelOptUtil.InputFinder.bits(expr); return table.containsPartitionColumnsOnly(cols); } @@ -489,7 +712,33 @@ public Double visitLiteral(RexLiteral literal) { return null; } - private static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) { + /** + * Returns the selectivity of a predicate "val1 <= column < val2". + * @param kll the sketch + * @param boundaries the boundaries + * @return the selectivity of "val1 <= column < val2" + */ + private static double rangedSelectivity(KllFloatsSketch kll, Range boundaries) { + // convert the condition to a range val1 <= x < val2 + float newLower = BoundType.CLOSED.equals(boundaries.lowerBoundType()) ? boundaries.lowerEndpoint() + : Math.nextUp(boundaries.lowerEndpoint()); + float newUpper = BoundType.OPEN.equals(boundaries.upperBoundType()) ? boundaries.upperEndpoint() + : Math.nextUp(boundaries.upperEndpoint()); + Range closedOpen = Range.closedOpen(newLower, newUpper); + return rangedSelectivity(kll, closedOpen.lowerEndpoint(), closedOpen.upperEndpoint()); + } + + /** + * Returns the selectivity of a predicate "val1 <= column < val2". 
+ * @param kll the sketch + * @param val1 lower bound (inclusive) + * @param val2 upper bound (exclusive) + * @return the selectivity of "val1 <= column < val2" + */ + static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) { + if (val1 >= val2) { + return 0; + } float[] splitPoints = new float[] { val1, val2 }; double[] boundaries = kll.getCDF(splitPoints, QuantileSearchCriteria.EXCLUSIVE); return boundaries[1] - boundaries[0]; @@ -574,7 +823,7 @@ public static double betweenSelectivity(KllFloatsSketch kll, float leftValue, fl "Selectivity for BETWEEN leftValue AND rightValue when the two values coincide is not supported, found: " + "leftValue = " + leftValue + " and rightValue = " + rightValue); } - return rangedSelectivity(kll, Math.nextDown(leftValue), Math.nextUp(rightValue)); + return rangedSelectivity(kll, leftValue, Math.nextUp(rightValue)); } /** diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java index 4255c756e078..3c6b9098a0cf 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java @@ -17,7 +17,6 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.stats; -import com.google.common.collect.ImmutableList; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptPlanner; @@ -27,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; import 
org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; @@ -43,6 +45,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; import org.apache.hadoop.hive.ql.parse.CalcitePlanner; import org.apache.hadoop.hive.ql.plan.ColStatistics; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -51,24 +54,66 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.ZoneOffset; import java.util.Collections; +import static org.apache.calcite.sql.type.SqlTypeName.BIGINT; +import static org.apache.calcite.sql.type.SqlTypeName.DOUBLE; +import static org.apache.calcite.sql.type.SqlTypeName.FLOAT; +import static org.apache.calcite.sql.type.SqlTypeName.INTEGER; +import static org.apache.calcite.sql.type.SqlTypeName.SMALLINT; +import static org.apache.calcite.sql.type.SqlTypeName.TINYINT; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.betweenSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanOrEqualSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.isHistogramAvailable; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanOrEqualSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanSelectivity; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static 
org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class TestFilterSelectivityEstimator { private static final float[] VALUES = { 1, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6, 7 }; + private static final float[] VALUES2 = { + // rounding for DECIMAL(3,1) + // -99.95f and its two predecessors and successors + -99.95001f, -99.950005f, -99.95f, -99.94999f, -99.94998f, + // some values + 0f, 1f, 10f, + // rounding for DECIMAL(3,1) + // 99.95f and its two predecessors and successors + 99.94998f, 99.94999f, 99.95f, 99.950005f, 99.95001f, + // 100f and its two predecessors and successors + 99.999985f, 99.99999f, 100f, 100.00001f, 100.000015f, + // 100.05f and its two predecessors and successors + 100.04999f, 100.049995f, 100.05f, 100.05001f, 100.05002f, + // some values + 1_000f, 10_000f, 100_000f, 1_000_000f, 10_000_000f }; + + /** + * Both dates and timestamps are converted to epoch seconds. + *

+ * See {@link org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp#evaluate(GenericUDF.DeferredObject[])}. + */ + private static final float[] VALUES_TIME = { + timestamp("2020-11-01"), timestamp("2020-11-02"), timestamp("2020-11-03"), timestamp("2020-11-04"), + timestamp("2020-11-05T11:23:45Z"), timestamp("2020-11-06"), timestamp("2020-11-07") }; + private static final KllFloatsSketch KLL = StatisticsTestUtils.createKll(VALUES); - private static final float DELTA = Float.MIN_VALUE; + private static final KllFloatsSketch KLL2 = StatisticsTestUtils.createKll(VALUES2); + private static final KllFloatsSketch KLL_TIME = StatisticsTestUtils.createKll(VALUES_TIME); + private static final float DELTA = 1e-7f; private static final RexBuilder REX_BUILDER = new RexBuilder(new JavaTypeFactoryImpl(new HiveTypeSystemImpl())); private static final RelDataTypeFactory TYPE_FACTORY = REX_BUILDER.getTypeFactory(); + private static RelOptCluster relOptCluster; private static RexNode intMinus1; private static RexNode int0; @@ -85,7 +130,6 @@ public class TestFilterSelectivityEstimator { private static RexNode inputRef0; private static RexNode boolFalse; private static RexNode boolTrue; - private static ColStatistics stats; @Mock private RelOptSchema schemaMock; @@ -94,12 +138,14 @@ public class TestFilterSelectivityEstimator { @Mock private RelMetadataQuery mq; - private HiveTableScan tableScan; + private ColStatistics stats; private RelNode scan; + private RexNode currentInputRef; + private int currentValuesSize; @BeforeClass public static void beforeClass() { - RelDataType integerType = TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER); + RelDataType integerType = TYPE_FACTORY.createSqlType(INTEGER); intMinus1 = REX_BUILDER.makeLiteral(-1, integerType, true); int0 = REX_BUILDER.makeLiteral(0, integerType, true); int1 = REX_BUILDER.makeLiteral(1, integerType, true); @@ -113,25 +159,54 @@ public static void beforeClass() { int11 = REX_BUILDER.makeLiteral(11, integerType, 
true); boolFalse = REX_BUILDER.makeLiteral(false, TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true); boolTrue = REX_BUILDER.makeLiteral(true, TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true); - tableType = TYPE_FACTORY.createStructType(ImmutableList.of(integerType), ImmutableList.of("f1")); + RelDataTypeFactory.Builder b = new RelDataTypeFactory.Builder(TYPE_FACTORY); + b.add("f_numeric", decimalType(38, 25)); + b.add("f_timestamp", SqlTypeName.TIMESTAMP); + b.add("f_date", SqlTypeName.DATE).build(); + tableType = b.build(); RelOptPlanner planner = CalcitePlanner.createPlanner(new HiveConf()); relOptCluster = RelOptCluster.create(planner, REX_BUILDER); + } - stats = new ColStatistics(); - stats.setHistogram(KLL.toByteArray()); + private static ColStatistics.Range rangeOf(float[] values) { + float min = Float.MAX_VALUE, max = -Float.MAX_VALUE; + for (float v : values) { + min = Math.min(min, v); + max = Math.max(max, v); + } + return new ColStatistics.Range(min, max); } @Before public void before() { + currentValuesSize = VALUES.length; doReturn(tableType).when(tableMock).getRowType(); - doReturn((double) VALUES.length).when(tableMock).getRowCount(); + when(tableMock.getRowCount()).thenAnswer(a -> (double) currentValuesSize); RelBuilder relBuilder = HiveRelFactories.HIVE_BUILDER.create(relOptCluster, schemaMock); - tableScan = new HiveTableScan(relOptCluster, relOptCluster.traitSetOf(HiveRelNode.CONVENTION), - tableMock, "table", null, false, false); + HiveTableScan tableScan = + new HiveTableScan(relOptCluster, relOptCluster.traitSetOf(HiveRelNode.CONVENTION), tableMock, "table", null, + false, false); scan = relBuilder.push(tableScan).build(); inputRef0 = REX_BUILDER.makeInputRef(scan, 0); + currentInputRef = inputRef0; + + stats = new ColStatistics(); + stats.setHistogram(KLL.toByteArray()); + stats.setRange(rangeOf(VALUES)); + } + + /** + * Note: call this method only at the beginning of a test method. 
+ */ + private void useFieldWithValues(String fieldname, float[] values, KllFloatsSketch sketch) { + currentValuesSize = values.length; + stats.setHistogram(sketch.toByteArray()); + stats.setRange(rangeOf(values)); + int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname); + currentInputRef = REX_BUILDER.makeInputRef(scan, fieldIndex); + doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(fieldIndex)); } @Test @@ -420,7 +495,7 @@ public void testComputeRangePredicateSelectivityBetweenLeftLowerThanRight() { @Test public void testComputeRangePredicateSelectivityBetweenLeftEqualsRight() { - doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); + verify(tableMock, never()).getColStat(any()); doReturn(10.0).when(mq).getDistinctRowCount(scan, ImmutableBitSet.of(0), REX_BUILDER.makeLiteral(true)); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, inputRef0, int3, int3); FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); @@ -454,7 +529,7 @@ public void testComputeRangePredicateSelectivityNotBetweenRightLowerThanLeft() { @Test public void testComputeRangePredicateSelectivityNotBetweenLeftEqualsRight() { - doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); + verify(tableMock, never()).getColStat(any()); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, inputRef0, int3, int3); FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); Assert.assertEquals(1, estimator.estimateSelectivity(filter), DELTA); @@ -511,6 +586,271 @@ public void testComputeRangePredicateSelectivityNotBetweenWithNULLS() { doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, inputRef0, int1, int3); FilterSelectivityEstimator estimator = new 
FilterSelectivityEstimator(scan, mq); - Assert.assertEquals(0.55, estimator.estimateSelectivity(filter), DELTA); + // only the values 4, 5, 6, 7 fulfill the condition NOT BETWEEN 1 AND 3 + // (the NULL values do not fulfill the condition) + Assert.assertEquals(0.2, estimator.estimateSelectivity(filter), DELTA); + } + + @Test + public void testRangePredicateWithCast() { + useFieldWithValues("f_numeric", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", TINYINT), int5)); + checkSelectivity(10 / 13.f, lt(cast("f_numeric", TINYINT), int5)); + checkSelectivity(2 / 13.f, gt(cast("f_numeric", TINYINT), int5)); + checkSelectivity(11 / 13.f, le(cast("f_numeric", TINYINT), int5)); + + checkSelectivity(12 / 13f, ge(cast("f_numeric", TINYINT), int2)); + checkSelectivity(1 / 13f, lt(cast("f_numeric", TINYINT), int2)); + checkSelectivity(5 / 13f, gt(cast("f_numeric", TINYINT), int2)); + checkSelectivity(8 / 13f, le(cast("f_numeric", TINYINT), int2)); + + // check some types + checkSelectivity(3 / 13.f, ge(cast("f_numeric", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", BIGINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", FLOAT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", DOUBLE), int5)); + } + + @Test + public void testRangePredicateWithCast2() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + RelDataType decimal3s1 = decimalType(3, 1); + checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(1))); + + // values from -99.94999 to 99.94999 (both inclusive) + checkSelectivity(7 / 28.f, lt(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(7 / 28.f, le(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(100))); + + RelDataType decimal4s1 = 
decimalType(4, 1); + checkSelectivity(10 / 28.f, lt(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(20 / 28.f, le(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(13 / 28.f, ge(cast("f_numeric", decimal4s1), literalFloat(100))); + + RelDataType decimal2s1 = decimalType(2, 1); + checkSelectivity(2 / 28.f, lt(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(2 / 28.f, le(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal2s1), literalFloat(100))); + + // expected: 100_000f + RelDataType decimal7s1 = decimalType(7, 1); + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10000))); + + // expected: 10_000f, 100_000f, because CAST(1_000_000 AS DECIMAL(7,1)) = NULL, and similar for even larger values + checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), literalFloat(9999))); + checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), literalFloat(10000))); + + // expected: 100_000f + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10000))); + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10001))); + + // expected 1f, 10f, 99.94998f, 99.94999f + checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(1))); + checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal3s1), literalFloat(1))); + // expected -99.94999f, -99.94998f, 0f, 1f + checkSelectivity(4 / 28.f, le(cast("f_numeric", decimal3s1), literalFloat(1))); + checkSelectivity(3 / 28.f, lt(cast("f_numeric", decimal3s1), literalFloat(1))); + + // the cast would apply a modulo operation to the values outside the range of the cast + // so instead a default selectivity should be returned + checkSelectivity(1 / 3.f, 
lt(cast("f_numeric", TINYINT), literalFloat(100))); + checkSelectivity(1 / 3.f, lt(cast("f_numeric", TINYINT), literalFloat(100))); + } + + private void checkTimeFieldOnMidnightTimestamps(RexNode field) { + // note: use only values from VALUES_TIME that specify a date without hh:mm:ss! + checkSelectivity(7 / 7.f, ge(field, literalTimestamp("2020-11-01"))); + checkSelectivity(5 / 7.f, ge(field, literalTimestamp("2020-11-03"))); + checkSelectivity(1 / 7.f, ge(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(6 / 7.f, gt(field, literalTimestamp("2020-11-01"))); + checkSelectivity(4 / 7.f, gt(field, literalTimestamp("2020-11-03"))); + checkSelectivity(0 / 7.f, gt(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(1 / 7.f, le(field, literalTimestamp("2020-11-01"))); + checkSelectivity(3 / 7.f, le(field, literalTimestamp("2020-11-03"))); + checkSelectivity(7 / 7.f, le(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(0 / 7.f, lt(field, literalTimestamp("2020-11-01"))); + checkSelectivity(2 / 7.f, lt(field, literalTimestamp("2020-11-03"))); + checkSelectivity(6 / 7.f, lt(field, literalTimestamp("2020-11-07"))); + } + + private void checkTimeFieldOnIntraDayTimestamps(RexNode field) { + checkSelectivity(3 / 7.f, ge(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(2 / 7.f, gt(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(5 / 7.f, le(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(4 / 7.f, lt(field, literalTimestamp("2020-11-05T11:23:45Z"))); + } + + @Test + public void testRangePredicateOnTimestamp() { + useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(currentInputRef); + checkTimeFieldOnIntraDayTimestamps(currentInputRef); + } + + @Test + public void testRangePredicateOnTimestampWithCast() { + useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME); + RexNode expr1 = cast("f_timestamp", SqlTypeName.DATE); + 
checkTimeFieldOnMidnightTimestamps(expr1); + checkTimeFieldOnIntraDayTimestamps(expr1); + + RexNode expr2 = cast("f_timestamp", SqlTypeName.TIMESTAMP); + checkTimeFieldOnMidnightTimestamps(expr2); + checkTimeFieldOnIntraDayTimestamps(expr2); + } + + @Test + public void testRangePredicateOnDate() { + useFieldWithValues("f_date", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(currentInputRef); + + // it does not make sense to compare with "2020-11-05T11:23:45Z", + // as that value would not be stored as-is in a date column, but as "2020-11-05" instead + } + + @Test + public void testRangePredicateOnDateWithCast() { + useFieldWithValues("f_date", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.DATE)); + checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.TIMESTAMP)); + + // it does not make sense to compare with "2020-11-05T11:23:45Z", + // as that value would not be stored as-is in a date column, but as "2020-11-05" instead + } + + @Test + public void testBetweenWithCastDecimal2s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 2; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(2, 1), inputRef0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f); + checkBetweenSelectivity(1, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastDecimal3s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 7; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(3, 1), inputRef0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f); + checkBetweenSelectivity(4, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void 
testBetweenWithCastDecimal4s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 23; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(4, 1), inputRef0); + // the values between -999.94999... and 999.94999... (both inclusive) pass through the cast + // the values between 99.95 and 100 are rounded up to 100, so they fulfill the BETWEEN + checkBetweenSelectivity(13, universe, total, cast, 100, 1000); + checkBetweenSelectivity(14, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastDecimal7s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 26; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(7, 1), inputRef0); + checkBetweenSelectivity(14, universe, total, cast, 100, 1000); + checkBetweenSelectivity(14, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + private void checkSelectivity(float expectedSelectivity, RexNode filter) { + FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); + Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(filter), DELTA); + + // convert "col OP value" to "value INVERSE_OP col", and check it + RexNode inverted = RexUtil.invert(REX_BUILDER, (RexCall) filter); + if (inverted != null) { + Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(inverted), DELTA); + } + } + + private void checkBetweenSelectivity(float expectedEntries, float universe, float total, RexNode value, float lower, + float upper) { + RexNode betweenFilter = + REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, value, literalFloat(lower), literalFloat(upper)); + FilterSelectivityEstimator estimator = new 
FilterSelectivityEstimator(scan, mq); + String between = "BETWEEN " + lower + " AND " + upper; + float expectedSelectivity = expectedEntries / total; + String message = between + ": calcite filter " + betweenFilter.toString(); + Assert.assertEquals(message, expectedSelectivity, estimator.estimateSelectivity(betweenFilter), DELTA); + + // invert the filter to a NOT BETWEEN + RexNode invBetween = + REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, value, literalFloat(lower), literalFloat(upper)); + String invMessage = "NOT " + between + ": calcite filter " + invBetween.toString(); + float invExpectedSelectivity = (universe - expectedEntries) / total; + Assert.assertEquals(invMessage, invExpectedSelectivity, estimator.estimateSelectivity(invBetween), DELTA); + } + + private RexNode cast(String fieldname, SqlTypeName typeName) { + return cast(fieldname, type(typeName)); + } + + private RexNode cast(String fieldname, RelDataType type) { + int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname); + RexNode column = REX_BUILDER.makeInputRef(scan, fieldIndex); + return REX_BUILDER.makeCast(type, column); + } + + private RexNode ge(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, expr, value); + } + + private RexNode gt(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN, expr, value); + } + + private RexNode le(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, expr, value); + } + + private RexNode lt(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN, expr, value); + } + + private static RelDataType type(SqlTypeName typeName) { + return REX_BUILDER.getTypeFactory().createSqlType(typeName); + } + + private static RelDataType decimalType(int precision, int scale) { + return REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.DECIMAL, precision, scale); + } + + private 
static RexLiteral literalTimestamp(String timestamp) { + return REX_BUILDER.makeLiteral(timestampMillis(timestamp), + REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.TIMESTAMP)); + } + + private RexNode literalFloat(float f) { + return REX_BUILDER.makeLiteral(f, type(SqlTypeName.FLOAT)); + } + + private static long timestampMillis(String timestamp) { + if (!timestamp.contains(":")) { + return LocalDate.parse(timestamp).toEpochSecond(LocalTime.MIDNIGHT, ZoneOffset.UTC) * 1000; + } + return Instant.parse(timestamp).toEpochMilli(); + } + + private static long timestamp(String timestamp) { + return timestampMillis(timestamp) / 1000; } }