From bed032d4d86be1067f1c57c7079524c214166b65 Mon Sep 17 00:00:00 2001 From: Konstantin Bereznyakov Date: Thu, 21 May 2026 15:53:07 -0700 Subject: [PATCH 1/5] HIVE-29625: Disambiguate ColStatistics.countDistinct "unknown" from "verified zero" --- .../ql/optimizer/ReduceSinkMapJoinProc.java | 3 +- .../optimizer/SetHashGroupByMinReduction.java | 4 +- .../SortedDynPartitionOptimizer.java | 6 +- .../ql/optimizer/calcite/RelOptHiveTable.java | 5 +- .../stats/HiveRelMdDistinctRowCount.java | 4 + .../calcite/stats/HiveRelMdRowCount.java | 30 +++-- .../annotation/StatsRulesProcFactory.java | 126 +++++++++++++----- .../hadoop/hive/ql/parse/TezCompiler.java | 10 ++ .../hadoop/hive/ql/plan/Statistics.java | 6 +- .../hadoop/hive/ql/stats/StatsUtils.java | 63 ++++++--- .../estimator/PessimisticStatCombiner.java | 4 +- 11 files changed, 191 insertions(+), 70 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java index ea192e8af9aa..b1a93ffcaaf1 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java @@ -187,7 +187,8 @@ public static Object processReduceSinkToHashJoin(ReduceSinkOperator parentRS, Ma ExprNodeDesc realCol = parentRS.getColumnExprMap().get(prefix + "." 
+ keyCol); ColStatistics cs = StatsUtils.getColStatisticsFromExpression(context.conf, stats, realCol); - if (cs == null || cs.getCountDistint() <= 0) { + if (cs == null || cs.getCountDistint() < 0) { + // unknown: same fallback as old "no stats / overloaded NDV=0" path maxKeyCount = Long.MAX_VALUE; break; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SetHashGroupByMinReduction.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SetHashGroupByMinReduction.java index bbd474b842f8..06a1a5ba8849 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SetHashGroupByMinReduction.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SetHashGroupByMinReduction.java @@ -69,8 +69,8 @@ public Object process(Node nd, Stack stack, Statistics parentStats = groupByOperator.getParentOperators().get(0).getStatistics(); long ndvProduct = StatsUtils.computeNDVGroupingColumns( colStats, parentStats, true); - // if ndvProduct is 0 then column stats state must be partial and we are missing - if (ndvProduct == 0) { + if (ndvProduct < 0) { + // unknown product - same fallback as old "overloaded NDV=0" path return null; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java index f5431fa34934..1d3df730c93c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java @@ -935,7 +935,8 @@ private long computePartCardinality(List partitionPos, for (Integer idx : partitionPos) { ColumnInfo ci = fsParent.getSchema().getSignature().get(idx); ColStatistics partStats = tStats.getColumnStatisticsFromColName(ci.getInternalName()); - if (partStats == null) { + // countDistinct < 0 means "unknown" - same path as missing stats + if (partStats == null || partStats.getCountDistint() < 0) { return -1; } partCardinality *= 
partStats.getCountDistint(); @@ -950,7 +951,8 @@ private long computePartCardinality(List partitionPos, // implementations on UDFs (e.g. iceberg_bucket reports min(inputNDV, numBuckets)) ColStatistics exprStats = StatsUtils.getColStatisticsFromExpression( this.parseCtx.getConf(), tStats, resolved); - if (exprStats == null) { + // countDistinct < 0 means "unknown" - same path as missing stats + if (exprStats == null || exprStats.getCountDistint() < 0) { return -1; } partCardinality *= exprStats.getCountDistint(); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java index b3e0249b778b..fd440557ddda 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java @@ -588,7 +588,10 @@ private void updateColStats(Set projIndxLst, boolean allowMissingStats) rowCount = 0; hiveColStats = new ArrayList(); for (int i = 0; i < nonPartColNamesThatRqrStats.size(); i++) { - // add empty stats object for each column + // empty stats object for each column: all fields take their Java defaults + // (countDistinct=0, numNulls=0, ...). Under the ColStatistics convention this + // reads as "verified zero" rather than "unknown", which is semantically correct + // here because rowCount is 0 - the table has zero rows after partition pruning. 
hiveColStats.add( new ColStatistics( nonPartColNamesThatRqrStats.get(i), diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java index b0f40d0d815e..57b03da7a088 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java @@ -60,6 +60,10 @@ public Double getDistinctRowCount(HiveTableScan htRel, RelMetadataQuery mq, Immu List colStats = htRel.getColStat(projIndxLst); Double noDistinctRows = 1.0; for (ColStatistics cStat : colStats) { + // countDistinct < 0 means "unknown" - signal back to Calcite via null + if (cStat.getCountDistint() < 0) { + return null; + } noDistinctRows *= cStat.getCountDistint(); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java index 254a5ed8c839..b36ce395840e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java @@ -319,8 +319,11 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery int pkSide = leftIsKey ? 0 : 1; boolean isPKSideSimpleTree = leftIsKey ? SimpleTreeOnJoinKey.check(false, left, lBitSet, mq).left : SimpleTreeOnJoinKey.check(false, right, rBitSet, mq).left; - double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; - double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; + // getDistinctRowCount returns null when NDV is unknown; box to avoid NPE on unboxing + Double leftNDVBoxed = isPKSideSimpleTree ? 
mq.getDistinctRowCount(left, lBitSet, leftPred) : null; + Double rightNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : null; + double leftNDV = leftNDVBoxed == null ? -1 : leftNDVBoxed; + double rightNDV = rightNDVBoxed == null ? -1 : rightNDVBoxed; /* * If the ndv of the PK - FK side don't match, and the PK side is a filter @@ -344,8 +347,13 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery * d_date column we can apply the scaling factor. */ double ndvScalingFactor = 1.0; - if ( isPKSideSimpleTree ) { - ndvScalingFactor = pkSide == 0 ? leftNDV/rightNDV : rightNDV / leftNDV; + // denominator must be strictly positive to avoid div-by-zero; the numerator may be 0 + if (isPKSideSimpleTree) { + if (pkSide == 0 && leftNDV >= 0 && rightNDV > 0) { + ndvScalingFactor = leftNDV / rightNDV; + } else if (pkSide != 0 && rightNDV >= 0 && leftNDV > 0) { + ndvScalingFactor = rightNDV / leftNDV; + } } if (pkSide == 0) { @@ -441,8 +449,11 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK rexBuilder, leftFilters, true); RexNode rightPred = RexUtil.composeConjunction( rexBuilder, rightFilters, true); - double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; - double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; + // getDistinctRowCount returns null when NDV is unknown; box to avoid NPE on unboxing + Double leftNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : null; + Double rightNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : null; + double leftNDV = leftNDVBoxed == null ? -1 : leftNDVBoxed; + double rightNDV = rightNDVBoxed == null ? 
-1 : rightNDVBoxed; // 5) Add the rest of operators back to the join filters // and create residual condition @@ -459,7 +470,9 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK leftNDV, join.getJoinType().generatesNullsOnRight() ? 1.0 : pkSelectivity); - double ndvScalingFactor = isPKSideSimpleTree ? leftNDV/rightNDV : 1.0; + // denominator must be strictly positive to avoid div-by-zero; the numerator may be 0 + double ndvScalingFactor = (isPKSideSimpleTree && leftNDV >= 0 && rightNDV > 0) + ? leftNDV/rightNDV : 1.0; return Pair.of(new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isNoFilteringPKSideTree), residualCond); } else { // pkSide == 1 @@ -470,7 +483,8 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK rightNDV, join.getJoinType().generatesNullsOnLeft() ? 1.0 : pkSelectivity); - double ndvScalingFactor = isPKSideSimpleTree ? rightNDV/leftNDV : 1.0; + double ndvScalingFactor = (isPKSideSimpleTree && rightNDV >= 0 && leftNDV > 0) + ? rightNDV/leftNDV : 1.0; return Pair.of(new PKFKRelationInfo(0, fkInfo, pkInfo, ndvScalingFactor, isNoFilteringPKSideTree), residualCond); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java index abfe6170217e..cdabac1b08dc 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java @@ -557,11 +557,18 @@ private long evaluateInExpr(Statistics stats, ExprNodeDesc pred, long currNumRow factor *= children.size() - 1; } for (int i = 0; i < columnStats.size(); i++) { - long dvs = columnStats.get(i) == null ? 0 : columnStats.get(i).getCountDistint(); - if (dvs == 0) { + ColStatistics cs = columnStats.get(i); + long dvs = cs == null ? 
-1L : cs.getCountDistint(); + if (dvs < 0) { + // missing stats or unknown NDV factor *= 0.5; continue; } + if (dvs == 0) { + // verified zero distinct values: IN cannot match any row + factor = 0; + break; + } // (num of distinct vals for col in IN clause / num of distinct vals for col ) double columnFactor = 1.0 / dvs; if (!multiColumn) { @@ -1318,7 +1325,13 @@ private long evaluateChildExpr(Statistics stats, ExprNodeDesc child, ColStatistics cs = stats.getColumnStatisticsFromColName(colName); if (cs != null) { long dvs = cs.getCountDistint(); - numRows = dvs == 0 ? numRows / 2 : Math.round((double) numRows / dvs); + if (dvs < 0) { + numRows = numRows / 2; // unknown + } else if (dvs == 0) { + numRows = 0; // verified zero distinct values - no rows match + } else { + numRows = Math.round((double) numRows / dvs); + } return numRows; } } else if (leaf instanceof ExprNodeColumnDesc) { @@ -1339,7 +1352,13 @@ private long evaluateChildExpr(Statistics stats, ExprNodeDesc child, ColStatistics cs = stats.getColumnStatisticsFromColName(colName); if (cs != null) { long dvs = cs.getCountDistint(); - numRows = dvs == 0 ? 
numRows / 2 : Math.round((double) numRows / dvs); + if (dvs < 0) { + numRows = numRows / 2; // unknown + } else if (dvs == 0) { + numRows = 0; // verified zero distinct values - no rows match + } else { + numRows = Math.round((double) numRows / dvs); + } return numRows; } } @@ -1518,14 +1537,13 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, // compute product of distinct values of grouping columns long ndvProduct = StatsUtils.computeNDVGroupingColumns(colStats, parentStats, false); - // if ndvProduct is 0 then column stats state must be partial and we are missing - // column stats for a group by column - if (ndvProduct == 0) { + if (ndvProduct < 0) { + // unknown - missing column stats or unknown NDV on a grouping column ndvProduct = parentNumRows / 2; if (LOG.isDebugEnabled()) { - LOG.debug("STATS-" + gop.toString() + ": ndvProduct became 0 as some column does not" + - " have stats. ndvProduct changed to: " + ndvProduct); + LOG.debug("STATS-" + gop.toString() + ": ndvProduct unknown; falling back to " + + ndvProduct); } } final long maxColumnNDV = colStats.stream() @@ -1720,6 +1738,10 @@ static void computeAggregateColumnMinMax(ColStatistics cs, HiveConf conf, Aggreg long valuesCount = agg.getDistinct() ? parentCS.getCountDistint() : parentStats.getNumRows() - numNulls; + // countDistinct < 0 would produce a Range with a negative maxValue + if (agg.getDistinct() && valuesCount < 0) { + return; + } Range range = parentCS.getRange(); // Get the aggregate function matching the name in the query. 
GenericUDAFResolver udaf = @@ -1819,9 +1841,24 @@ private boolean checkMapSideAggregation(GroupByOperator gop, // estimate size of key from column statistics long avgKeySize = 0; + // lazily computed on first unknown NDV (null = not yet looked up) + Long parentNumRows = null; for (ColStatistics cs : colStats) { if (cs != null) { - numEstimatedRows = StatsUtils.safeMult(numEstimatedRows, cs.getCountDistint()); + long ndv = cs.getCountDistint(); + if (ndv < 0) { + if (parentNumRows == null) { + // unknown NDV: fall back to parentNumRows / 2, matching the heuristic + // used elsewhere in this file when GROUP BY cardinality cannot be computed + Statistics parentStats = gop.getParentOperators().get(0).getStatistics(); + parentNumRows = (parentStats != null) ? parentStats.getNumRows() : -1L; + } + if (parentNumRows <= 0) { + return false; + } + ndv = parentNumRows / 2; + } + numEstimatedRows = StatsUtils.safeMult(numEstimatedRows, ndv); avgKeySize += Math.ceil(cs.getAvgColLen()); } } @@ -2248,14 +2285,15 @@ private long calculateUnmatchedRowsForOuter(HiveConf conf, long inputRowCount, distinctVal = StatsUtils.addWithExpDecay(distinctVals); } } - // If we have a greater number of unmatched values than number of distinct values, - // we just return the number of rows in the input as we can assume there are no - // matches - if (distinctUnmatched >= distinctVal) { + // distinctVal <= 0 covers unknown (<0) and verified-zero (==0) cases; the latter means + // no key value matches anything, so every input row is unmatched in an outer join. + // distinctUnmatched < 0 (unknown) is treated conservatively the same way. + // If unmatched >= distinctVal, all rows can be assumed unmatched. 
+ if (distinctVal <= 0 || distinctUnmatched < 0 || distinctUnmatched >= distinctVal) { return inputRowCount; } // Otherwise, divide the number of input rows by the number of distinct values - // and divide by the number of distinct values unmatched + // and multiply by the number of distinct values unmatched return StatsUtils.safeMult(inputRowCount / distinctVal, distinctUnmatched); } @@ -2632,26 +2670,29 @@ private void updateColStats(HiveConf conf, Statistics stats, long leftUnmatchedR int pos = jop.getConf().getReversedExprs().get(cs.getColumnName()); long oldDV = cs.getCountDistint(); - boolean useCalciteForNdvReadjustment - = HiveConf.getBoolVar(conf, ConfVars.HIVE_STATS_JOIN_NDV_READJUSTMENT); - long newDV = oldDV; - if (useCalciteForNdvReadjustment) { - Double approxNdv = RelMdUtil.numDistinctVals(oldDV * 1.0, newNumRows * 1.0); - Preconditions.checkNotNull(approxNdv, "approximate NDV is null"); - newDV = approxNdv.longValue(); - } else { - long oldRowCount = rowCountParents.get(pos); - double ratio = (double) newNumRows / (double) oldRowCount; - - // if ratio is greater than 1, then number of rows increases. This can happen - // when some operators like GROUPBY duplicates the input rows in which case - // number of distincts should not change. Update the distinct count only when - // the output number of rows is less than input number of rows. 
- if (ratio <= 1.0) { - newDV = (long) Math.ceil(ratio * oldDV); + // countDistinct < 0 means "unknown" + if (oldDV >= 0) { + boolean useCalciteForNdvReadjustment + = HiveConf.getBoolVar(conf, ConfVars.HIVE_STATS_JOIN_NDV_READJUSTMENT); + long newDV = oldDV; + if (useCalciteForNdvReadjustment) { + Double approxNdv = RelMdUtil.numDistinctVals(oldDV * 1.0, newNumRows * 1.0); + Preconditions.checkNotNull(approxNdv, "approximate NDV is null"); + newDV = approxNdv.longValue(); + } else { + long oldRowCount = rowCountParents.get(pos); + double ratio = (double) newNumRows / (double) oldRowCount; + + // if ratio is greater than 1, then number of rows increases. This can happen + // when some operators like GROUPBY duplicates the input rows in which case + // number of distincts should not change. Update the distinct count only when + // the output number of rows is less than input number of rows. + if (ratio <= 1.0) { + newDV = (long) Math.ceil(ratio * oldDV); + } } + cs.setCountDistint(newDV); } - cs.setCountDistint(newDV); updateNumNulls(cs, leftUnmatchedRows, rightUnmatchedRows, newNumRows, pos, jop); } stats.setColumnStats(colStats); @@ -2734,7 +2775,12 @@ private long computeRowCountAssumingInnerJoin(List rowCountParents, long d } } - denom = denom == 0 ? 
1 : denom; + // denom < 0 (unknown) and denom == 0 (verified-zero join key, cardinality formula + // degenerate) both fall back to "no constraint" rather than producing a negative + // factor or div-by-zero + if (denom <= 0) { + denom = 1; + } factor = (double) max / (double) denom; for (int i = 0; i < rowCountParents.size(); i++) { @@ -2786,6 +2832,12 @@ private long getDenominatorForUnmatchedRows(List distinctVals) { if (distinctVals.isEmpty()) { return 2; } + // any unknown (<0) contributor makes the result unknown + for (Long v : distinctVals) { + if (v < 0) { + return -1L; + } + } // simple join from 2 relations: denom = min(v1, v2) if (distinctVals.size() <= 2) { @@ -2826,6 +2878,12 @@ private long getDenominator(List distinctVals) { // denominator is 2. return 2; } + // any unknown (<0) contributor makes the result unknown + for (Long v : distinctVals) { + if (v < 0) { + return -1L; + } + } // simple join from 2 relations: denom = max(v1, v2) if (distinctVals.size() <= 2) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java index ede30bfb946d..09a08388126d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java @@ -1624,6 +1624,12 @@ private static double getProbeDecodeNDVRatio(TableScanOperator tsOp, MapJoinOper tsKeyCardinality = tsStats.getCountDistint(); } } + if (tsKeyCardinality <= 0) { + // verified-zero or negative-sentinel denominator: avoid Infinity/negative ratio. 
+ // Use 1.0 (neutral - broadcast and probe equal) so this MJ doesn't artificially win + // the lowest-ratio selector + return 1.0; + } return mjKeyCardinality / (double) tsKeyCardinality; } @@ -1750,6 +1756,10 @@ private static double getBloomFilterSelectivity( + ", tsKeyCardinality=" + tsKeyCardinality + ", keyDomainCardinality=" + keyDomainCardinality); } // Selectivity: key cardinality of semijoin / domain cardinality + // keyDomainCardinality <= 0 (verified zero or unknown) makes the formula degenerate + if (keyDomainCardinality <= 0) { + return 1.0; + } return selKeyCardinality / (double) keyDomainCardinality; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/plan/Statistics.java b/ql/src/java/org/apache/hadoop/hive/ql/plan/Statistics.java index d672b7acfc22..ab048a94dc29 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/plan/Statistics.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/plan/Statistics.java @@ -260,7 +260,11 @@ public void addToColumnStats(List colStats) { } else { existing.setNumNulls(StatsUtils.safeAdd(existing.getNumNulls(), cs.getNumNulls())); } - existing.setCountDistint(Math.max(existing.getCountDistint(), cs.getCountDistint())); + if (cs.getCountDistint() < 0 || existing.getCountDistint() < 0) { + existing.setCountDistint(-1); + } else { + existing.setCountDistint(Math.max(existing.getCountDistint(), cs.getCountDistint())); + } } } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java index 55f9d0c1e158..68a741b625d2 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java @@ -813,7 +813,7 @@ public static ColStatistics getColStatistics(ColumnStatisticsObj cso, String col } else if (colTypeLowerCase.equals(serdeConstants.STRING_TYPE_NAME) || colTypeLowerCase.startsWith(serdeConstants.CHAR_TYPE_NAME) || colTypeLowerCase.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) { - 
cs.setCountDistint(csd.getStringStats().getNumDVs()); + cs.setCountDistint(csd.getStringStats().isSetNumDVs() ? csd.getStringStats().getNumDVs() : -1); cs.setNumNulls(csd.getStringStats().getNumNulls()); cs.setAvgColLen(csd.getStringStats().getAvgColLen()); cs.setBitVectors(csd.getStringStats().getBitVectors()); @@ -837,9 +837,12 @@ public static ColStatistics getColStatistics(ColumnStatisticsObj cso, String col } else if (colTypeLowerCase.equals(serdeConstants.BINARY_TYPE_NAME)) { cs.setAvgColLen(csd.getBinaryStats().getAvgColLen()); cs.setNumNulls(csd.getBinaryStats().getNumNulls()); + // BinaryColumnStatsData has no numDVs field - the metastore does not track NDV + // for binary columns, so it is genuinely unknown + cs.setCountDistint(-1); } else if (colTypeLowerCase.equals(serdeConstants.TIMESTAMP_TYPE_NAME)) { cs.setAvgColLen(JavaDataModel.get().lengthOfTimestamp()); - cs.setCountDistint(csd.getTimestampStats().getNumDVs()); + cs.setCountDistint(csd.getTimestampStats().isSetNumDVs() ? csd.getTimestampStats().getNumDVs() : -1); cs.setNumNulls(csd.getTimestampStats().getNumNulls()); Long lowVal = (csd.getTimestampStats().getLowValue() != null) ? csd.getTimestampStats().getLowValue() .getSecondsSinceEpoch() : null; @@ -852,7 +855,7 @@ public static ColStatistics getColStatistics(ColumnStatisticsObj cso, String col cs.setAvgColLen(JavaDataModel.get().lengthOfTimestamp()); } else if (colTypeLowerCase.startsWith(serdeConstants.DECIMAL_TYPE_NAME)) { cs.setAvgColLen(JavaDataModel.get().lengthOfDecimal()); - cs.setCountDistint(csd.getDecimalStats().getNumDVs()); + cs.setCountDistint(csd.getDecimalStats().isSetNumDVs() ? 
csd.getDecimalStats().getNumDVs() : -1); cs.setNumNulls(csd.getDecimalStats().getNumNulls()); Decimal highValue = csd.getDecimalStats().getHighValue(); Decimal lowValue = csd.getDecimalStats().getLowValue(); @@ -871,7 +874,7 @@ public static ColStatistics getColStatistics(ColumnStatisticsObj cso, String col cs.setHistogram(csd.getDecimalStats().getHistogram()); } else if (colTypeLowerCase.equals(serdeConstants.DATE_TYPE_NAME)) { cs.setAvgColLen(JavaDataModel.get().lengthOfDate()); - cs.setCountDistint(csd.getDateStats().getNumDVs()); + cs.setCountDistint(csd.getDateStats().isSetNumDVs() ? csd.getDateStats().getNumDVs() : -1); cs.setNumNulls(csd.getDateStats().getNumNulls()); Long lowVal = (csd.getDateStats().getLowValue() != null) ? csd.getDateStats().getLowValue() .getDaysSinceEpoch() : null; @@ -900,7 +903,7 @@ public static void fillColumnStatisticsData(ColumnStatisticsData data, ColStatis private static void fillColStatisticsFromLongStatsData(ColStatistics cs, LongColumnStatsData longStats, double avgColLen) { - cs.setCountDistint(longStats.getNumDVs()); + cs.setCountDistint(longStats.isSetNumDVs() ? longStats.getNumDVs() : -1); cs.setNumNulls(longStats.getNumNulls()); cs.setAvgColLen(avgColLen); Long lowVal = longStats.isSetLowValue() ? longStats.getLowValue() : null; @@ -912,7 +915,7 @@ private static void fillColStatisticsFromLongStatsData(ColStatistics cs, LongCol private static void fillColStatisticsFromDoubleStatsData(ColStatistics cs, DoubleColumnStatsData doubleStats, double avgColLen) { - cs.setCountDistint(doubleStats.getNumDVs()); + cs.setCountDistint(doubleStats.isSetNumDVs() ? doubleStats.getNumDVs() : -1); cs.setNumNulls(doubleStats.getNumNulls()); cs.setAvgColLen(avgColLen); Double lowVal = doubleStats.isSetLowValue() ? doubleStats.getLowValue() : null; @@ -1690,6 +1693,12 @@ public static Long addWithExpDecay (List distinctVals) { // Exponential back-off for NDVs. 
// 1) Descending order sort of NDVs // 2) denominator = NDV1 * (NDV2 ^ (1/2)) * (NDV3 ^ (1/4))) * .... + // any unknown (<0) contributor makes the result unknown + for (Long v : distinctVals) { + if (v < 0) { + return -1L; + } + } distinctVals.sort(Collections.reverseOrder()); long denom = distinctVals.get(0); @@ -1716,6 +1725,10 @@ private static long getNDVFor(ExprNodeGenericFuncDesc engfd, long numRows, Stati for (String col : engfd.getCols()) { ColStatistics stats = parentStats.getColumnStatisticsFromColName(col); if (stats != null) { + // countDistinct < 0 means "unknown" + if (stats.getCountDistint() < 0) { + return -1L; + } ndvs.add(stats.getCountDistint()); } } @@ -2036,20 +2049,23 @@ public static void updateStats(Statistics stats, long newNumRows, for (ColStatistics cs : colStats) { long oldDV = cs.getCountDistint(); if (affectedColumns.contains(cs.getColumnName())) { - long newDV = oldDV; - - // if ratio is greater than 1, then number of rows increases. This can happen - // when some operators like GROUPBY duplicates the input rows in which case - // number of distincts should not change. Update the distinct count only when - // the output number of rows is less than input number of rows. - if (ratio <= 1.0) { - newDV = (long) Math.ceil(ratio * oldDV); - } - cs.setCountDistint(newDV); cs.setFilterColumn(); - oldDV = newDV; + // countDistinct < 0 means "unknown" - skip the NDV math + if (oldDV >= 0) { + long newDV = oldDV; + + // if ratio is greater than 1, then number of rows increases. This can happen + // when some operators like GROUPBY duplicates the input rows in which case + // number of distincts should not change. Update the distinct count only when + // the output number of rows is less than input number of rows. 
+ if (ratio <= 1.0) { + newDV = (long) Math.ceil(ratio * oldDV); + } + cs.setCountDistint(newDV); + oldDV = newDV; + } } - if (oldDV > newNumRows) { + if (oldDV >= 0 && oldDV > newNumRows) { cs.setCountDistint(newNumRows); } // numNulls < 0 means "unknown" - preserve the sentinel value @@ -2080,7 +2096,8 @@ public static void scaleColStatistics(List colStats, double facto if (cs.getNumNulls() >= 0) { cs.setNumNulls(StatsUtils.safeMult(cs.getNumNulls(), factor)); } - if (factor < 1.0) { + // countDistinct < 0 means "unknown" - preserve the sentinel value + if (factor < 1.0 && cs.getCountDistint() >= 0) { final double newNDV = Math.ceil(cs.getCountDistint() * factor); cs.setCountDistint(newNDV > Long.MAX_VALUE ? Long.MAX_VALUE : (long) newNDV); } @@ -2092,7 +2109,8 @@ public static long computeNDVGroupingColumns(List colStats, Stati List ndvValues = extractNDVGroupingColumns(colStats, parentStats); if (ndvValues == null) { - return 0L; + // unknown: a grouping column has NDV<0 or stats are missing on a partial state + return -1L; } if (ndvValues.isEmpty()) { // No grouping columns, one row @@ -2112,6 +2130,11 @@ private static List extractNDVGroupingColumns(List colStats for (ColStatistics cs : colStats) { if (cs != null) { long ndv = cs.getCountDistint(); + // countDistinct < 0 means "unknown" - signal it like a missing entry + if (ndv < 0) { + ndvValues = null; + break; + } if (cs.getNumNulls() > 0) { ndv = StatsUtils.safeAdd(ndv, 1); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/stats/estimator/PessimisticStatCombiner.java b/ql/src/java/org/apache/hadoop/hive/ql/stats/estimator/PessimisticStatCombiner.java index 4de2867de7c0..2a3e16c48235 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/stats/estimator/PessimisticStatCombiner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/stats/estimator/PessimisticStatCombiner.java @@ -41,7 +41,9 @@ public void add(ColStatistics stat) { if (stat.getAvgColLen() > result.getAvgColLen()) { 
result.setAvgColLen(stat.getAvgColLen()); } - if (stat.getCountDistint() > result.getCountDistint()) { + if (stat.getCountDistint() < 0 || result.getCountDistint() < 0) { + result.setCountDistint(-1); + } else if (stat.getCountDistint() > result.getCountDistint()) { result.setCountDistint(stat.getCountDistint()); } if (stat.getNumNulls() < 0 || result.getNumNulls() < 0) { From 124abe620d63e490f50acb32ec969b7014ffc63d Mon Sep 17 00:00:00 2001 From: Konstantin Bereznyakov Date: Fri, 22 May 2026 14:33:15 -0700 Subject: [PATCH 2/5] HIVE-29625: impacted .out files + reverting an unintended edit --- .../apache/hadoop/hive/ql/parse/TezCompiler.java | 10 ---------- ...pes_non_dictionary_encoding_vectorization.q.out | 14 +++++++------- .../llap/vector_binary_join_groupby.q.out | 2 +- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java index 09a08388126d..ede30bfb946d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java @@ -1624,12 +1624,6 @@ private static double getProbeDecodeNDVRatio(TableScanOperator tsOp, MapJoinOper tsKeyCardinality = tsStats.getCountDistint(); } } - if (tsKeyCardinality <= 0) { - // verified-zero or negative-sentinel denominator: avoid Infinity/negative ratio. 
- // Use 1.0 (neutral - broadcast and probe equal) so this MJ doesn't artificially win - // the lowest-ratio selector - return 1.0; - } return mjKeyCardinality / (double) tsKeyCardinality; } @@ -1756,10 +1750,6 @@ private static double getBloomFilterSelectivity( + ", tsKeyCardinality=" + tsKeyCardinality + ", keyDomainCardinality=" + keyDomainCardinality); } // Selectivity: key cardinality of semijoin / domain cardinality - // keyDomainCardinality <= 0 (verified zero or unknown) makes the formula degenerate - if (keyDomainCardinality <= 0) { - return 1.0; - } return selKeyCardinality / (double) keyDomainCardinality; } diff --git a/ql/src/test/results/clientpositive/llap/parquet_types_non_dictionary_encoding_vectorization.q.out b/ql/src/test/results/clientpositive/llap/parquet_types_non_dictionary_encoding_vectorization.q.out index d4d9cb53e2b9..0b708705624b 100644 --- a/ql/src/test/results/clientpositive/llap/parquet_types_non_dictionary_encoding_vectorization.q.out +++ b/ql/src/test/results/clientpositive/llap/parquet_types_non_dictionary_encoding_vectorization.q.out @@ -2414,13 +2414,13 @@ STAGE PLANS: minReductionHashAggr: 0.99 mode: hash outputColumnNames: _col0, _col1 - Statistics: Num rows: 1 Data size: 48 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 150 Data size: 1960 Basic stats: COMPLETE Column stats: COMPLETE Reduce Output Operator key expressions: _col0 (type: binary) null sort order: z sort order: + Map-reduce partition columns: _col0 (type: binary) - Statistics: Num rows: 1 Data size: 48 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 150 Data size: 1960 Basic stats: COMPLETE Column stats: COMPLETE value expressions: _col1 (type: bigint) Execution mode: vectorized, llap LLAP IO: all inputs (cache only) @@ -2432,16 +2432,16 @@ STAGE PLANS: keys: KEY._col0 (type: binary) mode: mergepartial outputColumnNames: _col0, _col1 - Statistics: Num rows: 1 Data size: 48 Basic stats: COMPLETE Column stats: COMPLETE + 
Statistics: Num rows: 75 Data size: 1000 Basic stats: COMPLETE Column stats: COMPLETE Select Operator expressions: hex(_col0) (type: string), _col1 (type: bigint), _col0 (type: binary) outputColumnNames: _col0, _col1, _col2 - Statistics: Num rows: 1 Data size: 232 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 75 Data size: 14800 Basic stats: COMPLETE Column stats: COMPLETE Reduce Output Operator key expressions: _col2 (type: binary) null sort order: z sort order: + - Statistics: Num rows: 1 Data size: 232 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 75 Data size: 14800 Basic stats: COMPLETE Column stats: COMPLETE value expressions: _col0 (type: string), _col1 (type: bigint) Reducer 3 Execution mode: vectorized, llap @@ -2449,10 +2449,10 @@ STAGE PLANS: Select Operator expressions: VALUE._col0 (type: string), VALUE._col1 (type: bigint) outputColumnNames: _col0, _col1 - Statistics: Num rows: 1 Data size: 192 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 75 Data size: 14400 Basic stats: COMPLETE Column stats: COMPLETE File Output Operator compressed: false - Statistics: Num rows: 1 Data size: 192 Basic stats: COMPLETE Column stats: COMPLETE + Statistics: Num rows: 75 Data size: 14400 Basic stats: COMPLETE Column stats: COMPLETE table: input format: org.apache.hadoop.mapred.SequenceFileInputFormat output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat diff --git a/ql/src/test/results/clientpositive/llap/vector_binary_join_groupby.q.out b/ql/src/test/results/clientpositive/llap/vector_binary_join_groupby.q.out index 41bc14e5e354..8e8d1d548e0e 100644 --- a/ql/src/test/results/clientpositive/llap/vector_binary_join_groupby.q.out +++ b/ql/src/test/results/clientpositive/llap/vector_binary_join_groupby.q.out @@ -137,7 +137,7 @@ STAGE PLANS: TableScan alias: t1 filterExpr: bin is not null (type: boolean) - probeDecodeDetails: cacheKey:HASH_MAP_MAPJOIN_30_container, bigKeyColName:bin, 
smallTablePos:1, keyRatio:0.0 + probeDecodeDetails: cacheKey:HASH_MAP_MAPJOIN_30_container, bigKeyColName:bin, smallTablePos:1, keyRatio:100.0 Statistics: Num rows: 100 Data size: 34084 Basic stats: COMPLETE Column stats: COMPLETE TableScan Vectorization: native: true From f5b19e9c38f103a93efff457cfbbba1d283bdf44 Mon Sep 17 00:00:00 2001 From: Konstantin Bereznyakov Date: Mon, 25 May 2026 08:25:00 -0700 Subject: [PATCH 3/5] HIVE-29625: itest code, small tweaks + better code reuse --- .../SortedDynPartitionOptimizer.java | 4 +- .../ql/optimizer/calcite/RelOptHiveTable.java | 5 +- .../stats/HiveRelMdDistinctRowCount.java | 9 +- .../calcite/stats/HiveRelMdRowCount.java | 30 +- .../annotation/StatsRulesProcFactory.java | 55 +-- .../hadoop/hive/ql/stats/StatsUtils.java | 16 +- .../optimizer/TestReduceSinkMapJoinProc.java | 180 +++++++ .../TestSetHashGroupByMinReduction.java | 169 +++++++ .../TestSortedDynPartitionOptimizer.java | 178 +++++++ .../stats/TestHiveRelMdDistinctRowCount.java | 92 ++++ .../annotation/TestStatsRulesProcFactory.java | 441 +++++++++++++++++- .../hadoop/hive/ql/stats/TestStatsUtils.java | 249 ++++++++++ .../TestPessimisticStatCombiner.java | 30 ++ 13 files changed, 1377 insertions(+), 81 deletions(-) create mode 100644 ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java create mode 100644 ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java create mode 100644 ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java create mode 100644 ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java index 1d3df730c93c..aaa851b7942b 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java +++ 
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SortedDynPartitionOptimizer.java @@ -90,6 +90,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -924,7 +925,8 @@ private long computeMaxWriters() { * Computes the partition cardinality based on column NDV statistics. * @return positive value = estimated cardinality, 0 = no partition columns, -1 = stats unavailable */ - private long computePartCardinality(List partitionPos, + @VisibleForTesting + long computePartCardinality(List partitionPos, List, ExprNodeDesc>> customPartitionExprs, Statistics tStats, Operator fsParent, ArrayList allRSCols) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java index fd440557ddda..b3e0249b778b 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/RelOptHiveTable.java @@ -588,10 +588,7 @@ private void updateColStats(Set projIndxLst, boolean allowMissingStats) rowCount = 0; hiveColStats = new ArrayList(); for (int i = 0; i < nonPartColNamesThatRqrStats.size(); i++) { - // empty stats object for each column: all fields take their Java defaults - // (countDistinct=0, numNulls=0, ...). Under the ColStatistics convention this - // reads as "verified zero" rather than "unknown", which is semantically correct - // here because rowCount is 0 - the table has zero rows after partition pruning. 
+ // add empty stats object for each column hiveColStats.add( new ColStatistics( nonPartColNamesThatRqrStats.get(i), diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java index 57b03da7a088..eeb599e848dd 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdDistinctRowCount.java @@ -22,6 +22,7 @@ import org.apache.calcite.rel.convert.Converter; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Spool; +import com.google.common.annotations.VisibleForTesting; import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider; import org.apache.calcite.rel.metadata.RelMdDistinctRowCount; import org.apache.calcite.rel.metadata.RelMdUtil; @@ -50,7 +51,8 @@ public class HiveRelMdDistinctRowCount extends RelMdDistinctRowCount { ReflectiveRelMetadataProvider.reflectiveSource( BuiltInMethod.DISTINCT_ROW_COUNT.method, new HiveRelMdDistinctRowCount()); - private HiveRelMdDistinctRowCount() { + @VisibleForTesting + HiveRelMdDistinctRowCount() { } public Double getDistinctRowCount(HiveTableScan htRel, RelMetadataQuery mq, ImmutableBitSet groupKey, @@ -60,9 +62,8 @@ public Double getDistinctRowCount(HiveTableScan htRel, RelMetadataQuery mq, Immu List colStats = htRel.getColStat(projIndxLst); Double noDistinctRows = 1.0; for (ColStatistics cStat : colStats) { - // countDistinct < 0 means "unknown" - signal back to Calcite via null - if (cStat.getCountDistint() < 0) { - return null; + if (cStat.getCountDistint() <= 0) { + return 0.0; } noDistinctRows *= cStat.getCountDistint(); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java index 
b36ce395840e..254a5ed8c839 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.java @@ -319,11 +319,8 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery int pkSide = leftIsKey ? 0 : 1; boolean isPKSideSimpleTree = leftIsKey ? SimpleTreeOnJoinKey.check(false, left, lBitSet, mq).left : SimpleTreeOnJoinKey.check(false, right, rBitSet, mq).left; - // getDistinctRowCount returns null when NDV is unknown; box to avoid NPE on unboxing - Double leftNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : null; - Double rightNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : null; - double leftNDV = leftNDVBoxed == null ? -1 : leftNDVBoxed; - double rightNDV = rightNDVBoxed == null ? -1 : rightNDVBoxed; + double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; + double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; /* * If the ndv of the PK - FK side don't match, and the PK side is a filter @@ -347,13 +344,8 @@ public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery * d_date column we can apply the scaling factor. */ double ndvScalingFactor = 1.0; - // denominator must be strictly positive to avoid div-by-zero; the numerator may be 0 - if (isPKSideSimpleTree) { - if (pkSide == 0 && leftNDV >= 0 && rightNDV > 0) { - ndvScalingFactor = leftNDV / rightNDV; - } else if (pkSide != 0 && rightNDV >= 0 && leftNDV > 0) { - ndvScalingFactor = rightNDV / leftNDV; - } + if ( isPKSideSimpleTree ) { + ndvScalingFactor = pkSide == 0 ? 
leftNDV/rightNDV : rightNDV / leftNDV; } if (pkSide == 0) { @@ -449,11 +441,8 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK rexBuilder, leftFilters, true); RexNode rightPred = RexUtil.composeConjunction( rexBuilder, rightFilters, true); - // getDistinctRowCount returns null when NDV is unknown; box to avoid NPE on unboxing - Double leftNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : null; - Double rightNDVBoxed = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : null; - double leftNDV = leftNDVBoxed == null ? -1 : leftNDVBoxed; - double rightNDV = rightNDVBoxed == null ? -1 : rightNDVBoxed; + double leftNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(left, lBitSet, leftPred) : -1; + double rightNDV = isPKSideSimpleTree ? mq.getDistinctRowCount(right, rBitSet, rightPred) : -1; // 5) Add the rest of operators back to the join filters // and create residual condition @@ -470,9 +459,7 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK leftNDV, join.getJoinType().generatesNullsOnRight() ? 1.0 : pkSelectivity); - // denominator must be strictly positive to avoid div-by-zero; the numerator may be 0 - double ndvScalingFactor = (isPKSideSimpleTree && leftNDV >= 0 && rightNDV > 0) - ? leftNDV/rightNDV : 1.0; + double ndvScalingFactor = isPKSideSimpleTree ? leftNDV/rightNDV : 1.0; return Pair.of(new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isNoFilteringPKSideTree), residualCond); } else { // pkSide == 1 @@ -483,8 +470,7 @@ public static Pair constraintsBasedAnalyzeJoinForPKFK rightNDV, join.getJoinType().generatesNullsOnLeft() ? 1.0 : pkSelectivity); - double ndvScalingFactor = (isPKSideSimpleTree && rightNDV >= 0 && leftNDV > 0) - ? rightNDV/leftNDV : 1.0; + double ndvScalingFactor = isPKSideSimpleTree ? 
rightNDV/leftNDV : 1.0; return Pair.of(new PKFKRelationInfo(0, fkInfo, pkInfo, ndvScalingFactor, isNoFilteringPKSideTree), residualCond); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java index cdabac1b08dc..19985a2fe82c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java @@ -1324,15 +1324,7 @@ private long evaluateChildExpr(Statistics stats, ExprNodeDesc child, ColStatistics cs = stats.getColumnStatisticsFromColName(colName); if (cs != null) { - long dvs = cs.getCountDistint(); - if (dvs < 0) { - numRows = numRows / 2; // unknown - } else if (dvs == 0) { - numRows = 0; // verified zero distinct values - no rows match - } else { - numRows = Math.round((double) numRows / dvs); - } - return numRows; + return rowsAfterEqualityFilter(numRows, cs.getCountDistint()); } } else if (leaf instanceof ExprNodeColumnDesc) { ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) leaf; @@ -1351,15 +1343,7 @@ private long evaluateChildExpr(Statistics stats, ExprNodeDesc child, ColStatistics cs = stats.getColumnStatisticsFromColName(colName); if (cs != null) { - long dvs = cs.getCountDistint(); - if (dvs < 0) { - numRows = numRows / 2; // unknown - } else if (dvs == 0) { - numRows = 0; // verified zero distinct values - no rows match - } else { - numRows = Math.round((double) numRows / dvs); - } - return numRows; + return rowsAfterEqualityFilter(numRows, cs.getCountDistint()); } } } @@ -1399,6 +1383,16 @@ private long evaluateChildExpr(Statistics stats, ExprNodeDesc child, return numRows / 2; } + private static long rowsAfterEqualityFilter(long numRows, long dvs) { + if (dvs < 0) { + return numRows / 2; + } + if (dvs == 0) { + return 0; + } + return Math.round((double) numRows / dvs); + } + 
} /** @@ -2264,7 +2258,8 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, return null; } - private long calculateUnmatchedRowsForOuter(HiveConf conf, long inputRowCount, + @VisibleForTesting + long calculateUnmatchedRowsForOuter(HiveConf conf, long inputRowCount, List joinKeys, Statistics statistics, long distinctUnmatched) { // Extract the ndv from each of the columns involved in the join List distinctVals = new ArrayList<>(); @@ -2642,7 +2637,8 @@ void updateNumNulls(ColStatistics colStats, long leftUnmatchedRows, long rightUn colStats.setNumNulls(newNumNulls); } - private void updateColStats(HiveConf conf, Statistics stats, long leftUnmatchedRows, long rightUnmatchedRows, + @VisibleForTesting + void updateColStats(HiveConf conf, Statistics stats, long leftUnmatchedRows, long rightUnmatchedRows, long newNumRows, CommonJoinOperator jop, Map rowCountParents) { if (newNumRows < 0) { @@ -2687,7 +2683,7 @@ private void updateColStats(HiveConf conf, Statistics stats, long leftUnmatchedR // when some operators like GROUPBY duplicates the input rows in which case // number of distincts should not change. Update the distinct count only when // the output number of rows is less than input number of rows. 
- if (ratio <= 1.0) { + if (ratio < 1.0) { newDV = (long) Math.ceil(ratio * oldDV); } } @@ -2759,7 +2755,8 @@ private long computeFinalRowCount(List rowCountParents, long interimRowCou return result; } - private long computeRowCountAssumingInnerJoin(List rowCountParents, long denom, + @VisibleForTesting + long computeRowCountAssumingInnerJoin(List rowCountParents, long denom, CommonJoinOperator join) { double factor = 0.0d; long result = 1; @@ -2832,11 +2829,8 @@ private long getDenominatorForUnmatchedRows(List distinctVals) { if (distinctVals.isEmpty()) { return 2; } - // any unknown (<0) contributor makes the result unknown - for (Long v : distinctVals) { - if (v < 0) { - return -1L; - } + if (StatsUtils.containsUnknownNDV(distinctVals)) { + return -1L; } // simple join from 2 relations: denom = min(v1, v2) @@ -2878,11 +2872,8 @@ private long getDenominator(List distinctVals) { // denominator is 2. return 2; } - // any unknown (<0) contributor makes the result unknown - for (Long v : distinctVals) { - if (v < 0) { - return -1L; - } + if (StatsUtils.containsUnknownNDV(distinctVals)) { + return -1L; } // simple join from 2 relations: denom = max(v1, v2) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java index 68a741b625d2..cba53f5df4ee 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/stats/StatsUtils.java @@ -1693,11 +1693,8 @@ public static Long addWithExpDecay (List distinctVals) { // Exponential back-off for NDVs. // 1) Descending order sort of NDVs // 2) denominator = NDV1 * (NDV2 ^ (1/2)) * (NDV3 ^ (1/4))) * .... 
- // any unknown (<0) contributor makes the result unknown - for (Long v : distinctVals) { - if (v < 0) { - return -1L; - } + if (containsUnknownNDV(distinctVals)) { + return -1L; } distinctVals.sort(Collections.reverseOrder()); @@ -2157,4 +2154,13 @@ private static List extractNDVGroupingColumns(List colStats return ndvValues; } + + /** + * Returns true if any value in the given list is the negative NDV "unknown" + * sentinel established by HIVE-29438 / HIVE-29625. Used by aggregators that + * must propagate unknown when any contributor is unknown. + */ + public static boolean containsUnknownNDV(List distinctVals) { + return distinctVals.stream().anyMatch(v -> v < 0); + } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java new file mode 100644 index 000000000000..415acfd10eff --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.optimizer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.Context; +import org.apache.hadoop.hive.ql.exec.MapJoinOperator; +import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; +import org.apache.hadoop.hive.ql.exec.Utilities; +import org.apache.hadoop.hive.ql.parse.GenTezProcContext; +import org.apache.hadoop.hive.ql.parse.ParseContext; +import org.apache.hadoop.hive.ql.plan.BaseWork; +import org.apache.hadoop.hive.ql.plan.ColStatistics; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.MapJoinDesc; +import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc; +import org.apache.hadoop.hive.ql.plan.Statistics; +import org.apache.hadoop.hive.ql.stats.StatsUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.MockedStatic; + +class TestReduceSinkMapJoinProc { + + /** + * Exercises the keyCount-from-NDV branch in processReduceSinkToHashJoin (HIVE-29625). + * Master used `cs.getCountDistint() <= 0` to fall back to MAX_VALUE; HIVE-29625 uses + * `cs.getCountDistint() < 0`, so verified-zero NDV no longer cascades to "no info" + * but falls through to `maxKeyCount *= 0` (then clamped to 1 by later logic). 
+ */ + @ParameterizedTest(name = "{0}") + @MethodSource("keyCountFromNdvCases") + void testProcessReduceSinkToHashJoinKeyCountFromNdv( + String scenarioName, long ndv, long parentRows, long expectedKeyCount) throws Exception { + invokeAndAssertKeyCount(buildColStat(ndv), parentRows, expectedKeyCount, scenarioName); + } + + private static Stream keyCountFromNdvCases() { + // Behavior recap: + // Initial keyCount = parentRows = stats.getNumRows() + // For each key col: + // if cs == null or cs.getCountDistint() < 0 -> maxKeyCount = MAX_VALUE; break <-- HIVE-29625 change + // else maxKeyCount *= cs.getCountDistint() + // keyCount = min(maxKeyCount, keyCount) + // if keyCount == 0 -> keyCount = 1 + // joinConf.getParentKeyCounts().put(pos, keyCount) [only if keyCount != MAX_VALUE] + return Stream.of( + // NDV > 0 and below parentRows -> keyCount = NDV + Arguments.of("knownPositiveBelowParent", 10L, 1000L, 10L), + // NDV > 0 but above parentRows -> capped at parentRows + Arguments.of("knownPositiveAboveParent", 5000L, 1000L, 1000L), + // NDV = 0 (verified zero under HIVE-29625) -> maxKeyCount = 0 -> keyCount = 0 -> bumped to 1 + Arguments.of("verifiedZeroBumpsToOne", 0L, 1000L, 1L), + // NDV = -1 (unknown under HIVE-29625) -> maxKeyCount = MAX_VALUE -> keyCount = parentRows + Arguments.of("unknownFallsBackToParent", -1L, 1000L, 1000L) + ); + } + + @Test + void testProcessReduceSinkToHashJoinNullColStatsFallsBackToParent() throws Exception { + // cs == null case (StatsUtils could not derive a stat for the expr) -> same fallback as < 0 + invokeAndAssertKeyCount(null, 1000L, 1000L, + "Null ColStatistics falls back to parent row count (same as NDV < 0)"); + } + + /** + * Shared invocation harness: build a real GenTezProcContext + mocked operator chain, + * stub StatsUtils.getColStatisticsFromExpression to return the given colStat for + * the single key column, run processReduceSinkToHashJoin, and assert the keyCount + * landed in joinConf.getParentKeyCounts() at position 0. 
+ */ + private static void invokeAndAssertKeyCount( + ColStatistics csForKey, long parentRows, long expectedKeyCount, String desc) throws Exception { + + // ---- Operator chain mocks ---- + ReduceSinkOperator parentRS = mock(ReduceSinkOperator.class); + MapJoinOperator mapJoinOp = mock(MapJoinOperator.class); + ReduceSinkDesc rsConf = mock(ReduceSinkDesc.class); + MapJoinDesc joinConf = mock(MapJoinDesc.class); + Statistics rsStats = mock(Statistics.class); + ExprNodeDesc keyExpr = mock(ExprNodeDesc.class); + BaseWork parentWork = mock(BaseWork.class); + + when(parentRS.getConf()).thenReturn(rsConf); + when(parentRS.getStatistics()).thenReturn(rsStats); + Map columnExprMap = new HashMap<>(); + columnExprMap.put(Utilities.ReduceField.KEY.toString() + ".k0", keyExpr); + when(parentRS.getColumnExprMap()).thenReturn(columnExprMap); + + when(mapJoinOp.getConf()).thenReturn(joinConf); + + when(rsConf.getOutputKeyColumnNames()).thenReturn(Arrays.asList("k0")); + + when(joinConf.isBucketMapJoin()).thenReturn(false); + when(joinConf.isDynamicPartitionHashJoin()).thenReturn(false); + Map parentKeyCounts = new LinkedHashMap<>(); + when(joinConf.getParentKeyCounts()).thenReturn(parentKeyCounts); + when(joinConf.getParentToInput()).thenReturn(new LinkedHashMap<>()); + when(joinConf.getParentDataSizes()).thenReturn(new LinkedHashMap<>()); + + when(rsStats.getNumRows()).thenReturn(parentRows); + when(rsStats.getDataSize()).thenReturn(8000L); + + when(parentWork.getName()).thenReturn("parent_work"); + + // ---- Real GenTezProcContext (constructor sets up all the maps for us) ---- + HiveConf conf = new HiveConf(); + ParseContext parseCtx = mock(ParseContext.class); + Context ctx = mock(Context.class); + when(parseCtx.getContext()).thenReturn(ctx); + when(ctx.getSequencer()).thenReturn(new AtomicInteger()); + GenTezProcContext context = new GenTezProcContext( + conf, parseCtx, Collections.emptyList(), new ArrayList<>(), + Collections.emptySet(), Collections.emptySet()); + + 
context.childToWorkMap.put(parentRS, Arrays.asList(parentWork)); + context.mapJoinParentMap.put(mapJoinOp, Arrays.asList(parentRS)); + + // ---- Stub StatsUtils.getColStatisticsFromExpression to return our chosen colStat ---- + try (MockedStatic stub = mockStatic(StatsUtils.class)) { + stub.when(() -> StatsUtils.getColStatisticsFromExpression( + any(HiveConf.class), any(Statistics.class), any(ExprNodeDesc.class))) + .thenReturn(csForKey); + + try { + ReduceSinkMapJoinProc.processReduceSinkToHashJoin(parentRS, mapJoinOp, context); + } catch (NullPointerException expected) { + // The method continues for ~100 more lines past the keyCount put() that this + // test is verifying - dummy-operator construction, key-table-desc setup, edge + // wiring, etc. - all requiring deep operator-chain mocking irrelevant to the + // HIVE-29625 NDV branch. By the time any NPE fires from that downstream code, + // joinConf.getParentKeyCounts() already received the put() we asserted on + // (line ~229 of ReduceSinkMapJoinProc, well before the first dependency this + // harness doesn't mock). + } + } + + Long actual = parentKeyCounts.get(0); + assertEquals(expectedKeyCount, actual.longValue(), desc); + } + + private static ColStatistics buildColStat(long ndv) { + ColStatistics cs = new ColStatistics("k0", "int"); + cs.setCountDistint(ndv); + return cs; + } +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java new file mode 100644 index 000000000000..639af98f701d --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.optimizer; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyFloat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; + +import org.apache.hadoop.hive.ql.exec.GroupByOperator; +import org.apache.hadoop.hive.ql.exec.Operator; +import org.apache.hadoop.hive.ql.exec.RowSchema; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.plan.GroupByDesc; +import org.apache.hadoop.hive.ql.plan.GroupByDesc.Mode; +import org.apache.hadoop.hive.ql.plan.Statistics; +import org.apache.hadoop.hive.ql.plan.Statistics.State; +import org.apache.hadoop.hive.ql.stats.StatsUtils; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; + +class TestSetHashGroupByMinReduction { + + @Test + void testProcessReturnsNullWhenNdvProductIsUnknown() throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(0.5f, 0.1f); + GroupByDesc desc = op.getConf(); + + try (MockedStatic stub = 
mockStatic(StatsUtils.class)) { + stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) + .thenReturn(-1L); + + Object result = new SetHashGroupByMinReduction().process(op, null, null); + + assertNull(result, "Unknown ndvProduct (-1) makes process() return null"); + verify(desc, never()).setMinReductionHashAggr(anyFloat()); + } + } + + @Test + void testProcessProceedsWhenNdvProductIsVerifiedZero() throws SemanticException { + // HIVE-29625 disambiguation: ndvProduct == 0 (verified) now distinct from ndvProduct < 0 (unknown). + // Verified-zero means there are zero distinct group keys -> maximum reduction (factor = 1.0). + GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); + GroupByDesc desc = op.getConf(); + + try (MockedStatic stub = mockStatic(StatsUtils.class)) { + stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) + .thenReturn(0L); + + Object result = new SetHashGroupByMinReduction().process(op, null, null); + + assertNull(result, "process() always returns null sentinel"); + // 1f - (0 / numRows) = 1.0; capped by lowerBound 0.5 -> stays 1.0; less than default 0.99? No, 1.0 > 0.99. + // So setMinReductionHashAggr is NOT called (newFactor not strictly less than default). 
+ verify(desc, never()).setMinReductionHashAggr(anyFloat()); + } + } + + @Test + void testProcessProceedsWhenNdvProductIsKnownPositive() throws SemanticException { + // numRows=1000, ndvProduct=100 -> factor = 1 - 100/1000 = 0.9 + // default = 0.99 -> 0.9 < 0.99 -> setMinReductionHashAggr(0.9) + GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.1f); + GroupByDesc desc = op.getConf(); + + try (MockedStatic stub = mockStatic(StatsUtils.class)) { + stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) + .thenReturn(100L); + + new SetHashGroupByMinReduction().process(op, null, null); + + verify(desc, atLeastOnce()).setMinReductionHashAggr(anyFloat()); + } + } + + @Test + void testProcessReturnsNullWhenModeNotHash() throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); + when(op.getConf().getMode()).thenReturn(Mode.MERGEPARTIAL); + + Object result = new SetHashGroupByMinReduction().process(op, null, null); + + assertNull(result, "Non-HASH mode -> early return null"); + verify(op.getConf(), never()).setMinReductionHashAggr(anyFloat()); + } + + @Test + void testProcessReturnsNullWhenBasicStatsIncomplete() throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); + when(op.getStatistics().getBasicStatsState()).thenReturn(State.PARTIAL); + + Object result = new SetHashGroupByMinReduction().process(op, null, null); + + assertNull(result, "Incomplete basic stats -> early return null"); + verify(op.getConf(), never()).setMinReductionHashAggr(anyFloat()); + } + + @Test + void testProcessReturnsNullWhenColumnStatsIncomplete() throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); + when(op.getStatistics().getColumnStatsState()).thenReturn(State.PARTIAL); + + Object result = new SetHashGroupByMinReduction().process(op, null, null); + + assertNull(result, "Incomplete column stats -> early return null"); + verify(op.getConf(), 
never()).setMinReductionHashAggr(anyFloat()); + } + + /** + * Build a GroupByOperator that passes all the early-return gates, with empty keys + * so the colStats loop is a no-op (the inputs to computeNDVGroupingColumns are + * controlled directly via mockStatic in each test). + */ + private static GroupByOperator setupCompleteHashGroupBy( + float defaultMinReduction, float defaultMinReductionLowerBound) { + GroupByOperator op = mock(GroupByOperator.class); + GroupByDesc desc = mock(GroupByDesc.class); + Statistics stats = mock(Statistics.class); + Operator parent = mock(Operator.class); + Statistics parentStats = mock(Statistics.class); + RowSchema schema = mock(RowSchema.class); + + when(op.getConf()).thenReturn(desc); + when(op.getStatistics()).thenReturn(stats); + when(op.getSchema()).thenReturn(schema); + when(schema.getSignature()).thenReturn(Collections.emptyList()); + + when(desc.getMode()).thenReturn(Mode.HASH); + when(desc.getKeys()).thenReturn(Collections.emptyList()); + when(desc.getMinReductionHashAggr()).thenReturn(defaultMinReduction); + when(desc.getMinReductionHashAggrLowerBound()).thenReturn(defaultMinReductionLowerBound); + + when(stats.getBasicStatsState()).thenReturn(State.COMPLETE); + when(stats.getColumnStatsState()).thenReturn(State.COMPLETE); + + when(parent.getStatistics()).thenReturn(parentStats); + when(parentStats.getNumRows()).thenReturn(1000L); + + when(op.getParentOperators()).thenReturn(Arrays.asList(parent)); + + return op; + } +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java new file mode 100644 index 000000000000..64d12657a452 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Stream;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.ColumnInfo;
+import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.RowSchema;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
+import org.apache.hadoop.hive.ql.plan.ColStatistics;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.Statistics;
+import org.apache.hadoop.hive.ql.stats.StatsUtils;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.MockedStatic;
+
+class TestSortedDynPartitionOptimizer {
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("computePartCardinalityColumnCases")
+  void testComputePartCardinalityColumnBranch(
+      String scenarioName, long[] ndvs, boolean firstStatNull, long expected) {
+    SortedDynPartitionOptimizer.SortedDynamicPartitionProc proc = newProc(null);
+
+    Statistics tStats = mock(Statistics.class);
+    Operator fsParent = mock(FileSinkOperator.class);
+    RowSchema schema = mock(RowSchema.class);
+    when(fsParent.getSchema()).thenReturn(schema);
+
+    List<ColumnInfo> sig = new ArrayList<>();
+    for (int i = 0; i < ndvs.length; i++) {
+      String colName = "p" + i;
+      ColumnInfo ci = mock(ColumnInfo.class);
+      when(ci.getInternalName()).thenReturn(colName);
+      sig.add(ci);
+
+      if (i == 0 && firstStatNull) {
+        when(tStats.getColumnStatisticsFromColName(colName)).thenReturn(null);
+      } else {
+        ColStatistics cs = new ColStatistics(colName, "int");
+        cs.setCountDistint(ndvs[i]);
+        when(tStats.getColumnStatisticsFromColName(colName)).thenReturn(cs);
+      }
+    }
+    when(schema.getSignature()).thenReturn(sig);
+
+    List<Integer> partitionPos = new ArrayList<>();
+    for (int i = 0; i < ndvs.length; i++) {
+      partitionPos.add(i);
+    }
+
+    long result = proc.computePartCardinality(
+        partitionPos, Collections.emptyList(), tStats, fsParent, new ArrayList<>());
+
+    assertEquals(expected, result, scenarioName);
+  }
+
+  private static Stream<Arguments> computePartCardinalityColumnCases() {
+    return Stream.of(
+        // All known positive NDVs - product computed
+        Arguments.of("twoPositiveColumns", new long[] {10L, 5L}, false, 50L),
+        Arguments.of("singlePositiveColumn", new long[] {42L}, false, 42L),
+        Arguments.of("threeColumnsCompound", new long[] {3L, 4L, 5L}, false, 60L),
+        // HIVE-29625: NDV<0 is unknown - returns -1
+        Arguments.of("unknownNDVShortCircuits", new long[] {10L, -1L, 5L}, false, -1L),
+        Arguments.of("firstUnknownShortCircuits", new long[] {-1L, 10L}, false, -1L),
+        Arguments.of("singleUnknownColumn", new long[] {-1L}, false, -1L),
+        // Verified zero NDV (HIVE-29625 disambiguation) - falls through to multiplication
+        Arguments.of("verifiedZeroProducesZero", new long[] {10L, 0L, 5L}, false, 0L),
+        Arguments.of("singleVerifiedZero", new long[] {0L}, false, 0L),
+        // Missing stats (partStats == null) - returns -1
+        Arguments.of("nullStatsShortCircuits", new long[] {0L, 10L}, true, -1L)
+    );
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("computePartCardinalityExprCases")
+  void testComputePartCardinalityCustomExprBranch(
+      String scenarioName, long[] ndvs, boolean firstStatNull, long expected) {
+    HiveConf conf = new HiveConf();
+    ParseContext parseCtx = mock(ParseContext.class);
+    when(parseCtx.getConf()).thenReturn(conf);
+    SortedDynPartitionOptimizer.SortedDynamicPartitionProc proc = newProc(parseCtx);
+
+    Statistics tStats = mock(Statistics.class);
+    Operator fsParent = mock(FileSinkOperator.class);
+    ArrayList<ExprNodeDesc> allRSCols = new ArrayList<>();
+    List<Function<List<ExprNodeDesc>, ExprNodeDesc>> exprs = new ArrayList<>();
+    List<ExprNodeDesc> resolvedExprs = new ArrayList<>();
+    for (int i = 0; i < ndvs.length; i++) {
+      ExprNodeDesc resolved = mock(ExprNodeDesc.class);
+      resolvedExprs.add(resolved);
+      exprs.add(cols -> resolved);
+    }
+
+    try (MockedStatic<StatsUtils> stub = mockStatic(StatsUtils.class)) {
+      for (int i = 0; i < ndvs.length; i++) {
+        final int idx = i;
+        final ColStatistics cs;
+        if (idx == 0 && firstStatNull) {
+          cs = null;
+        } else {
+          cs = new ColStatistics("e" + idx, "int");
+          cs.setCountDistint(ndvs[idx]);
+        }
+        stub.when(() -> StatsUtils.getColStatisticsFromExpression(eq(conf), eq(tStats), eq(resolvedExprs.get(idx))))
+            .thenReturn(cs);
+      }
+
+      long result = proc.computePartCardinality(
+          Collections.emptyList(), exprs, tStats, fsParent, allRSCols);
+
+      assertEquals(expected, result, scenarioName);
+    }
+  }
+
+  private static Stream<Arguments> computePartCardinalityExprCases() {
+    return Stream.of(
+        Arguments.of("singleKnownExpr", new long[] {7L}, false, 7L),
+        Arguments.of("twoKnownExprsMultiply", new long[] {3L, 4L}, false, 12L),
+        // HIVE-29625: NDV<0 from expression stats short-circuits
+        Arguments.of("unknownExprStatsShortCircuit", new long[] {5L, -1L}, false, -1L),
+        Arguments.of("firstExprUnknown", new long[] {-1L, 5L}, false, -1L),
+        // Verified zero from expression stats - falls through to multiplication
+        Arguments.of("verifiedZeroExprProducesZero", new long[] {5L, 0L}, false, 0L),
+        // Null expression stats (StatsUtils returned null) - returns -1
+        Arguments.of("nullExprStatsShortCircuits", new long[] {0L, 5L}, true, -1L)
+    );
+  }
+
+  @Test
+  void testComputePartCardinalityBothEmptyReturnsZero() {
+    SortedDynPartitionOptimizer.SortedDynamicPartitionProc proc = newProc(null);
+    long result = proc.computePartCardinality(
+        Collections.emptyList(), Collections.emptyList(),
+        mock(Statistics.class), mock(FileSinkOperator.class), new ArrayList<>());
+    assertEquals(0L, result, "Both partitionPos and customPartitionExprs empty -> 0");
+  }
+
+  private static SortedDynPartitionOptimizer.SortedDynamicPartitionProc newProc(ParseContext parseCtx) {
+    SortedDynPartitionOptimizer outer = new SortedDynPartitionOptimizer();
+    return outer.new SortedDynamicPartitionProc(parseCtx);
+  }
+}
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java
new file mode 100644
index 000000000000..60cc0cc28d5e
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.calcite.stats;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
+import org.apache.hadoop.hive.ql.plan.ColStatistics;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+class TestHiveRelMdDistinctRowCount {
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("getDistinctRowCountCases")
+  void testGetDistinctRowCountHiveTableScan(
+      String scenarioName, long[] ndvs, double rowCount, double expected) {
+    HiveTableScan htRel = mock(HiveTableScan.class);
+    RelMetadataQuery mq = mock(RelMetadataQuery.class);
+
+    when(htRel.getColStat(any())).thenReturn(buildColStats(ndvs));
+    when(mq.getRowCount(htRel)).thenReturn(rowCount);
+
+    HiveRelMdDistinctRowCount provider = new HiveRelMdDistinctRowCount();
+    Double result = provider.getDistinctRowCount(htRel, mq, ImmutableBitSet.of(0), null);
+
+    assertEquals(expected, result,
+        "getDistinctRowCount for ndvs=" + Arrays.toString(ndvs) + " rowCount=" + rowCount);
+  }
+
+  private static Stream<Arguments> getDistinctRowCountCases() {
+    return Stream.of(
+        // All positive, product fits under row count
+        Arguments.of("allPositiveProductUnderRowCount", new long[] {10L, 5L}, 1000.0, 50.0),
+        Arguments.of("singlePositive", new long[] {42L}, 1000.0, 42.0),
+        // Product exceeds row count -> capped
+        Arguments.of("productCappedAtRowCount", new long[] {2000L, 50L}, 1000.0, 1000.0),
+        Arguments.of("singlePositiveExceedsRowCount", new long[] {5000L}, 1000.0, 1000.0),
+        // Verified-zero NDV in any column triggers the <=0 early-exit
+        Arguments.of("verifiedZeroInAnyColumn", new long[] {10L, 0L, 5L}, 1000.0, 0.0),
+        Arguments.of("verifiedZeroFirstShortCircuits", new long[] {0L, 10L}, 1000.0, 0.0),
+        Arguments.of("verifiedZeroAlone", new long[] {0L}, 1000.0, 0.0),
+        // Unknown NDV (-1) in any column triggers the <=0 early-exit (HIVE-29625)
+        Arguments.of("unknownInAnyColumn", new long[] {10L, -1L, 5L}, 1000.0, 0.0),
+        Arguments.of("unknownFirstShortCircuits", new long[] {-1L, 10L}, 1000.0, 0.0),
+        Arguments.of("unknownAlone", new long[] {-1L}, 1000.0, 0.0),
+        // Mixed verified-zero and unknown: both produce 0.0 regardless of order
+        Arguments.of("unknownThenVerifiedZero", new long[] {-1L, 0L}, 1000.0, 0.0),
+        Arguments.of("verifiedZeroThenUnknown", new long[] {0L, -1L}, 1000.0, 0.0),
+        // Empty column list - loop doesn't execute, fall through to Math.min(1.0, rowCount)
+        Arguments.of("emptyColStatsFallsThroughTo1", new long[] {}, 1000.0, 1.0),
+        Arguments.of("emptyColStatsCappedByLowRowCount", new long[] {}, 0.5, 0.5)
+    );
+  }
+
+  private static List<ColStatistics> buildColStats(long[] ndvs) {
+    List<ColStatistics> stats = new ArrayList<>();
+    for (int i = 0; i < ndvs.length; i++) {
+      ColStatistics cs = new ColStatistics("c" + i, "int");
+      cs.setCountDistint(ndvs[i]);
+      stats.add(cs);
+    }
+    return stats;
+  }
+}
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java
index 4d9d351af8f1..1bcce8d02a12 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java
@@ -23,14 +23,25 @@
 import org.apache.hadoop.hive.common.type.Date;
 import org.apache.hadoop.hive.common.type.Timestamp;
 import org.apache.hadoop.hive.metastore.StatisticsTestUtils;
+import org.apache.hadoop.hive.ql.Context;
+import org.apache.hadoop.hive.ql.exec.ColumnInfo;
+import org.apache.hadoop.hive.ql.exec.GroupByOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.RowSchema;
+import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
+import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.ColStatistics;
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
+import org.apache.hadoop.hive.ql.plan.GroupByDesc;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.Statistics;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan;
@@ -39,21 +50,35 @@
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
-import org.junit.Test;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Stream;
 
 import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
 import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
 import org.apache.hadoop.hive.ql.plan.JoinDesc;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.apache.hadoop.hive.ql.optimizer.stats.annotation.StatsRulesProcFactory.FilterStatsRule.extractFloatFromLiteralValue;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 public class TestStatsRulesProcFactory {
 
@@ -102,6 +127,241 @@ public void testComparisonRowCountInvalidKll() throws SemanticException {
     assertEquals((VALUES.length + numNulls) / 3, numRows);
   }
 
+  @Test
+  public void testEvaluateInExprWithUnknownNDVAppliesHalfFactor() throws SemanticException {
+    // HIVE-29625: when the column's NDV is unknown (-1), the IN filter takes
+    // factor *= 0.5 and continues (rather than the old behavior of treating
+    // dvs==0 as unknown). currNumRows=13, factor=0.5, inFactor=1.0 (default).
+    Statistics stats = createStatistics(VALUES, 0);
+    stats.getColumnStats().get(0).setCountDistint(-1); // force unknown NDV
+
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(new HiveConf());
+
+    ExprNodeDesc inExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
+        new GenericUDFIn(),
+        Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1), createExprNodeConstantDesc(2)));
+
+    long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression(
+        stats, inExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length);
+
+    assertEquals(Math.round(VALUES.length * 0.5), numRows);
+  }
+
+  @Test
+  public void testEvaluateEqualWithUnknownNDVUsesHalfRows() throws SemanticException {
+    // HIVE-29625: col = const where col.NDV=-1 (unknown) falls back to numRows/2.
+    // VALUES.length=13, expected = 13/2 = 6 (long division).
+    Statistics stats = createStatistics(VALUES, 0);
+    stats.getColumnStats().get(0).setCountDistint(-1);
+
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(new HiveConf());
+
+    ExprNodeDesc eqExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
+        new GenericUDFOPEqual(),
+        Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1)));
+
+    long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression(
+        stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length);
+
+    assertEquals(6, numRows);
+  }
+
+  @Test
+  public void testEvaluateEqualWithVerifiedZeroNDVReturnsZero() throws SemanticException {
+    // HIVE-29625: col = const where col.NDV=0 (verified zero) returns 0 rows.
+    Statistics stats = createStatistics(VALUES, 0);
+    stats.getColumnStats().get(0).setCountDistint(0);
+
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(new HiveConf());
+
+    ExprNodeDesc eqExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
+        new GenericUDFOPEqual(),
+        Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1)));
+
+    long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression(
+        stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length);
+
+    assertEquals(0, numRows);
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("groupByFinalCases")
+  public void testGroupByStatsRuleFinalCardinality(String name, long keyNdv, long expectedRows) throws SemanticException {
+    assertGroupByFinalCardinality(keyNdv, expectedRows);
+  }
+
+  private static Stream<Arguments> groupByFinalCases() {
+    return Stream.of(
+        Arguments.of("ndvUnknownAppliesFallback", -1L, 500L),
+        Arguments.of("ndvVerifiedZeroFlowsThroughClampedToOne", 0L, 1L),
+        Arguments.of("ndvKnownUsesProduct", 10L, 10L)
+    );
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("groupByHashCases")
+  public void testCheckMapSideAggregationHashCardinality(String name, long keyNdv, long expectedRows) throws SemanticException {
+    assertGroupByHashCardinality(keyNdv, expectedRows);
+  }
+
+  private static Stream<Arguments> groupByHashCases() {
+    return Stream.of(
+        Arguments.of("ndvUnknownFallsBackToHalfParent", -1L, 500L),
+        Arguments.of("ndvKnownUsesProduct", 100L, 100L)
+    );
+  }
+
+  private void assertGroupByHashCardinality(long keyNdv, long expectedRows) throws SemanticException {
+    Statistics parentStats = new Statistics(1000, 8000, 0, 0);
+    parentStats.setBasicStatsState(Statistics.State.COMPLETE);
+    parentStats.setColumnStatsState(Statistics.State.COMPLETE);
+    ColStatistics keyCol = new ColStatistics("k", "int");
+    keyCol.setCountDistint(keyNdv);
+    keyCol.setNumNulls(0);
+    parentStats.setColumnStats(Collections.singletonList(keyCol));
+
+    @SuppressWarnings("unchecked")
+    Operator<? extends OperatorDesc> parent = mock(Operator.class);
+    when(parent.getStatistics()).thenReturn(parentStats);
+    when(parent.getParentOperators()).thenReturn(Collections.emptyList());
+
+    GroupByDesc gbyDesc = mock(GroupByDesc.class);
+    when(gbyDesc.getMode()).thenReturn(GroupByDesc.Mode.HASH);
+    when(gbyDesc.getAggregators()).thenReturn(Collections.emptyList());
+    when(gbyDesc.isGroupingSetsPresent()).thenReturn(false);
+    ExprNodeColumnDesc keyExpr = new ExprNodeColumnDesc(TypeInfoFactory.intTypeInfo, "k", "table", false);
+    when(gbyDesc.getKeys()).thenReturn(Collections.singletonList(keyExpr));
+
+    GroupByOperator gop = mock(GroupByOperator.class);
+    when(gop.getParentOperators()).thenReturn(Collections.singletonList(parent));
+    when(gop.getConf()).thenReturn(gbyDesc);
+    Map<String, ExprNodeDesc> colExprMap = new HashMap<>();
+    colExprMap.put("_col0", keyExpr);
+    when(gop.getColumnExprMap()).thenReturn(colExprMap);
+    RowSchema rs = mock(RowSchema.class);
+    ColumnInfo colInfo = new ColumnInfo("_col0", TypeInfoFactory.intTypeInfo, "table", false);
+    when(rs.getSignature()).thenReturn(Collections.singletonList(colInfo));
+    when(rs.getColumnInfo("_col0")).thenReturn(colInfo);
+    when(gop.getSchema()).thenReturn(rs);
+
+    Context context = mock(Context.class);
+    HiveConf conf = new HiveConf();
+    conf.setBoolVar(HiveConf.ConfVars.HIVE_QUERY_REEXECUTION_ENABLED, false);
+    when(context.getConf()).thenReturn(conf);
+    ParseContext pctx = mock(ParseContext.class);
+    when(pctx.getContext()).thenReturn(context);
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(conf);
+    when(ctx.getParseContext()).thenReturn(pctx);
+
+    // checkMapSideAggregation calls DagUtils.getContainerResource(conf) to compute
+    // the available hash-aggregation memory. Stub it to a generous 1024 MB so the
+    // estimated hash table size stays well under the threshold and hashAgg is selected.
+    try (MockedStatic<DagUtils> dagMock = mockStatic(DagUtils.class)) {
+      Resource res = mock(Resource.class);
+      when(res.getMemorySize()).thenReturn(1024L);
+      dagMock.when(() -> DagUtils.getContainerResource(any())).thenReturn(res);
+
+      new StatsRulesProcFactory.GroupByStatsRule().process(gop, null, ctx, (Object[]) null);
+    }
+
+    ArgumentCaptor<Statistics> captor = ArgumentCaptor.forClass(Statistics.class);
+    verify(gop).setStatistics(captor.capture());
+    assertEquals(expectedRows, captor.getValue().getNumRows());
+  }
+
+  private void assertGroupByFinalCardinality(long keyNdv, long expectedRows) throws SemanticException {
+    Statistics parentStats = new Statistics(1000, 8000, 0, 0);
+    parentStats.setBasicStatsState(Statistics.State.COMPLETE);
+    parentStats.setColumnStatsState(Statistics.State.COMPLETE);
+    ColStatistics keyCol = new ColStatistics("k", "int");
+    keyCol.setCountDistint(keyNdv);
+    keyCol.setNumNulls(0);
+    parentStats.setColumnStats(Collections.singletonList(keyCol));
+
+    @SuppressWarnings("unchecked")
+    Operator<? extends OperatorDesc> parent = mock(Operator.class);
+    when(parent.getStatistics()).thenReturn(parentStats);
+    when(parent.getParentOperators()).thenReturn(Collections.emptyList());
+
+    GroupByDesc gbyDesc = mock(GroupByDesc.class);
+    when(gbyDesc.getMode()).thenReturn(GroupByDesc.Mode.FINAL);
+    when(gbyDesc.getAggregators()).thenReturn(Collections.emptyList());
+    when(gbyDesc.isGroupingSetsPresent()).thenReturn(false);
+    ExprNodeColumnDesc keyExpr = new ExprNodeColumnDesc(TypeInfoFactory.intTypeInfo, "k", "table", false);
+    when(gbyDesc.getKeys()).thenReturn(Collections.singletonList(keyExpr));
+
+    GroupByOperator gop = mock(GroupByOperator.class);
+    when(gop.getParentOperators()).thenReturn(Collections.singletonList(parent));
+    when(gop.getConf()).thenReturn(gbyDesc);
+    Map<String, ExprNodeDesc> colExprMap = new HashMap<>();
+    colExprMap.put("_col0", keyExpr);
+    when(gop.getColumnExprMap()).thenReturn(colExprMap);
+    RowSchema rs = mock(RowSchema.class);
+    ColumnInfo colInfo = new ColumnInfo("_col0", TypeInfoFactory.intTypeInfo, "table", false);
+    when(rs.getSignature()).thenReturn(Collections.singletonList(colInfo));
+    when(rs.getColumnInfo("_col0")).thenReturn(colInfo);
+    when(gop.getSchema()).thenReturn(rs);
+
+    Context context = mock(Context.class);
+    HiveConf conf = new HiveConf();
+    conf.setBoolVar(HiveConf.ConfVars.HIVE_QUERY_REEXECUTION_ENABLED, false);
+    when(context.getConf()).thenReturn(conf);
+    ParseContext pctx = mock(ParseContext.class);
+    when(pctx.getContext()).thenReturn(context);
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(conf);
+    when(ctx.getParseContext()).thenReturn(pctx);
+
+    new StatsRulesProcFactory.GroupByStatsRule().process(gop, null, ctx, (Object[]) null);
+
+    ArgumentCaptor<Statistics> captor = ArgumentCaptor.forClass(Statistics.class);
+    verify(gop).setStatistics(captor.capture());
+    assertEquals(expectedRows, captor.getValue().getNumRows());
+  }
+
+  @Test
+  public void testEvaluateEqualWithKnownNDVUsesUniformDistribution() throws SemanticException {
+    // Regression check: col = const where col.NDV=7 returns round(13/7)=2 rows.
+    // VALUES has 7 distinct values, so createStatistics sets NDV=7.
+    Statistics stats = createStatistics(VALUES, 0);
+
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(new HiveConf());
+
+    ExprNodeDesc eqExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
+        new GenericUDFOPEqual(),
+        Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1)));
+
+    long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression(
+        stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length);
+
+    assertEquals(2, numRows);
+  }
+
+  @Test
+  public void testEvaluateInExprWithVerifiedZeroNDVReturnsZero() throws SemanticException {
+    // HIVE-29625: when the column's NDV is verified zero (0), the IN filter
+    // sets factor=0 and breaks out of the loop, so no rows match.
+    Statistics stats = createStatistics(VALUES, 0);
+    stats.getColumnStats().get(0).setCountDistint(0); // force verified-zero NDV
+
+    AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null));
+    when(ctx.getConf()).thenReturn(new HiveConf());
+
+    ExprNodeDesc inExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
+        new GenericUDFIn(),
+        Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1), createExprNodeConstantDesc(2)));
+
+    long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression(
+        stats, inExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length);
+
+    assertEquals(0, numRows);
+  }
+
   @Test
   public void testComparisonRowCountLessThan() throws SemanticException {
     long numNulls = 2;
@@ -676,10 +936,10 @@ public void testComputeAggregateColumnMinMaxWithUnknownNumNulls() throws Semanti
     // Verify: With the fix, COUNT Range should be (0, 100)
     // numNulls=-1 is treated as 0, so valuesCount = 100 - 0 = 100
     // Without the fix, valuesCount = 100 - (-1) = 101 (WRONG)
-    assertNotNull("Range should be set on COUNT column", cs.getRange());
-    assertEquals("COUNT min should be 0", 0L, ((Number) cs.getRange().minValue).longValue());
-    assertEquals("COUNT max should be 100 (numRows), not 101",
-        100L, ((Number) cs.getRange().maxValue).longValue());
+    assertNotNull(cs.getRange(), "Range should be set on COUNT column");
+    assertEquals(0L, ((Number) cs.getRange().minValue).longValue(), "COUNT min should be 0");
+    assertEquals(100L, ((Number) cs.getRange().maxValue).longValue(),
+        "COUNT max should be 100 (numRows), not 101");
   }
 
   @Test
@@ -708,10 +968,71 @@ public void testComputeAggregateColumnMinMaxWithKnownNumNulls() throws SemanticE
         cs, conf, agg, "bigint", parentStats);
 
     // With known numNulls=20, valuesCount = 100 - 20 = 80
-    assertNotNull("Range should be set", cs.getRange());
+    assertNotNull(cs.getRange(), "Range should be set");
     assertEquals(0L, ((Number) cs.getRange().minValue).longValue());
-    assertEquals("COUNT max should be 80 (numRows - numNulls)",
-        80L, ((Number) cs.getRange().maxValue).longValue());
+    assertEquals(80L, ((Number) cs.getRange().maxValue).longValue(),
+        "COUNT max should be 80 (numRows - numNulls)");
+  }
+
+  @Test
+  public void testComputeAggregateColumnMinMaxDistinctWithUnknownNDVReturnsEarly() throws SemanticException {
+    // HIVE-29625: for COUNT(DISTINCT col), valuesCount = parentCS.getCountDistint().
+    // When that NDV is -1 (unknown), the new guard returns early to avoid building
+    // a Range with a negative maxValue.
+    ColStatistics cs = new ColStatistics("_col0", "bigint");
+    HiveConf conf = new HiveConf();
+
+    ColStatistics parentColStats = new ColStatistics("val", "int");
+    parentColStats.setNumNulls(0);
+    parentColStats.setCountDistint(-1); // unknown NDV
+    parentColStats.setRange(1, 100);
+
+    Statistics parentStats = new Statistics(100, 400, 400, 400);
+    parentStats.addToColumnStats(Collections.singletonList(parentColStats));
+
+    ExprNodeColumnDesc colExpr = new ExprNodeColumnDesc(
+        TypeInfoFactory.intTypeInfo, "val", "t", false);
+    AggregationDesc agg = new AggregationDesc();
+    agg.setGenericUDAFName("count");
+    agg.setParameters(Collections.singletonList(colExpr));
+    agg.setDistinct(true);
+    agg.setMode(GenericUDAFEvaluator.Mode.COMPLETE);
+
+    StatsRulesProcFactory.GroupByStatsRule.computeAggregateColumnMinMax(
+        cs, conf, agg, "bigint", parentStats);
+
+    assertNull(cs.getRange(), "Range should NOT be set when DISTINCT NDV is unknown");
+  }
+
+  @Test
+  public void testComputeAggregateColumnMinMaxDistinctWithKnownNDVSetsRange() throws SemanticException {
+    // Regression: COUNT(DISTINCT col) with known parentCS.NDV=50 sets Range(0, 50).
+    ColStatistics cs = new ColStatistics("_col0", "bigint");
+    HiveConf conf = new HiveConf();
+
+    ColStatistics parentColStats = new ColStatistics("val", "int");
+    parentColStats.setNumNulls(0);
+    parentColStats.setCountDistint(50);
+    parentColStats.setRange(1, 100);
+
+    Statistics parentStats = new Statistics(100, 400, 400, 400);
+    parentStats.addToColumnStats(Collections.singletonList(parentColStats));
+
+    ExprNodeColumnDesc colExpr = new ExprNodeColumnDesc(
+        TypeInfoFactory.intTypeInfo, "val", "t", false);
+    AggregationDesc agg = new AggregationDesc();
+    agg.setGenericUDAFName("count");
+    agg.setParameters(Collections.singletonList(colExpr));
+    agg.setDistinct(true);
+    agg.setMode(GenericUDAFEvaluator.Mode.COMPLETE);
+
+    StatsRulesProcFactory.GroupByStatsRule.computeAggregateColumnMinMax(
+        cs, conf, agg, "bigint", parentStats);
+
+    assertNotNull(cs.getRange(), "Range should be set when DISTINCT NDV is known");
+    assertEquals(0L, ((Number) cs.getRange().minValue).longValue());
+    assertEquals(50L, ((Number) cs.getRange().maxValue).longValue(),
+        "COUNT DISTINCT max should equal the NDV (50)");
   }
 
   /**
@@ -749,7 +1070,101 @@ public void testUpdateNumNullsPreservesUnknownNumNulls() {
     joinStatsRule.updateNumNulls(colStats, 100L, 100L, 1000L, 0L, mockJop);
 
     // Assert that numNulls is still -1 (unchanged)
-    assertEquals("Unknown numNulls (-1) should be preserved after updateNumNulls",
-        -1L, colStats.getNumNulls());
+    assertEquals(-1L, colStats.getNumNulls(),
+        "Unknown numNulls (-1) should be preserved after updateNumNulls");
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("calculateUnmatchedRowsForOuterCases")
+  public void testCalculateUnmatchedRowsForOuter(
+      String name, long ndv, long distinctUnmatched, long expected) {
+    assertCalculateUnmatchedRowsForOuter(ndv, distinctUnmatched, expected);
+  }
+
+  private static Stream<Arguments> calculateUnmatchedRowsForOuterCases() {
+    return Stream.of(
+        Arguments.of("distinctValUnknownReturnsInputRowCount", -1L, 5L, 100L),
+        Arguments.of("distinctValVerifiedZeroReturnsInputRowCount", 0L, 5L, 100L),
+        Arguments.of("distinctUnmatchedUnknownReturnsInputRowCount", 10L, -1L, 100L),
+        Arguments.of("distinctUnmatchedExceedsReturnsInputRowCount", 10L, 15L, 100L),
+        Arguments.of("normalCaseDivides", 10L, 2L, 20L)
+    );
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("computeRowCountAssumingInnerJoinCases")
+  public void testComputeRowCountAssumingInnerJoin(String name, long denom, long expected) {
+    assertComputeRowCountAssumingInnerJoin(denom, expected);
+  }
+
+  private static Stream<Arguments> computeRowCountAssumingInnerJoinCases() {
+    return Stream.of(
+        Arguments.of("denomPositiveDivides", 10L, 2000L),
+        Arguments.of("denomZeroClampsToOne", 0L, 20000L),
+        Arguments.of("denomNegativeClampsToOne", -1L, 20000L)
+    );
+  }
+
+  @ParameterizedTest(name = "{0}")
+  @MethodSource("updateColStatsCases")
+  public void testUpdateColStats(String name, long initialNdv, long expectedNdv) {
+    ColStatistics cs = new ColStatistics("k", "int");
+    cs.setCountDistint(initialNdv);
+    cs.setNumNulls(0);
+    Statistics stats = new Statistics(1000, 8000, 0, 0);
+    stats.setColumnStats(Collections.singletonList(cs));
+
+    Map<String, Byte> reversedExprs = new HashMap<>();
+    reversedExprs.put("k", (byte) 0);
+    JoinCondDesc joinCond = mock(JoinCondDesc.class);
+    when(joinCond.getType()).thenReturn(JoinDesc.INNER_JOIN);
+    JoinDesc joinDesc = mock(JoinDesc.class);
+    when(joinDesc.getReversedExprs()).thenReturn(reversedExprs);
+    when(joinDesc.getConds()).thenReturn(new JoinCondDesc[]{joinCond});
+    when(joinDesc.getJoinKeys()).thenReturn(new ExprNodeDesc[][]{});
+    @SuppressWarnings("unchecked")
+    CommonJoinOperator<JoinDesc> jop = mock(CommonJoinOperator.class);
+    when(jop.getConf()).thenReturn(joinDesc);
+    RowSchema schema = mock(RowSchema.class);
+    when(schema.getColumnNames()).thenReturn(Collections.singletonList("k"));
+    when(schema.getSignature()).thenReturn(Collections.emptyList());
+    when(jop.getSchema()).thenReturn(schema);
+    Map<Integer, Long> rowCountParents = new HashMap<>();
+    rowCountParents.put(0, 1000L);
+    HiveConf conf = new HiveConf();
+    conf.setBoolVar(HiveConf.ConfVars.HIVE_STATS_JOIN_NDV_READJUSTMENT, false);
+
+    new StatsRulesProcFactory.JoinStatsRule().updateColStats(
+        conf, stats, 0L, 0L, 500L, jop, rowCountParents);
+
+    assertEquals(expectedNdv, cs.getCountDistint());
+  }
+
+  private static Stream<Arguments> updateColStatsCases() {
+    return Stream.of(
+        Arguments.of("unknownNdvSkipsMath", -1L, -1L),
+        Arguments.of("knownNdvScaledByRatio", 100L, 50L)
+    );
+  }
+
+  private void assertComputeRowCountAssumingInnerJoin(long denom, long expected) {
+    StatsRulesProcFactory.JoinStatsRule rule = new StatsRulesProcFactory.JoinStatsRule();
+    long actual = rule.computeRowCountAssumingInnerJoin(Arrays.asList(100L, 200L), denom, null);
+    assertEquals(expected, actual);
+  }
+
+  private void assertCalculateUnmatchedRowsForOuter(long ndv, long distinctUnmatched, long expected) {
+    HiveConf conf = new HiveConf();
+    ColStatistics cs = new ColStatistics("k", "int");
+    cs.setCountDistint(ndv);
+    cs.setNumNulls(0);
+    Statistics stats = new Statistics(100, 400, 0, 0);
+    stats.setColumnStats(Collections.singletonList(cs));
+
+    StatsRulesProcFactory.JoinStatsRule rule = new StatsRulesProcFactory.JoinStatsRule();
+    long actual = rule.calculateUnmatchedRowsForOuter(
+        conf, 100L, Collections.singletonList("k"), stats, distinctUnmatched);
+
+    assertEquals(expected, actual);
   }
 }
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/stats/TestStatsUtils.java b/ql/src/test/org/apache/hadoop/hive/ql/stats/TestStatsUtils.java
index c009472fed0a..679bf28734ed 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/stats/TestStatsUtils.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/stats/TestStatsUtils.java
@@ -25,6 +25,7 @@
 
 import java.lang.reflect.Field;
 import java.lang.reflect.Modifier;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Set;
@@ -32,13 +33,16 @@
 
 import
org.apache.commons.lang3.reflect.FieldUtils; import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.BinaryColumnStatsData; import org.apache.hadoop.hive.metastore.api.BooleanColumnStatsData; import org.apache.hadoop.hive.metastore.api.ColumnStatisticsData; import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Date; import org.apache.hadoop.hive.metastore.api.DateColumnStatsData; +import org.apache.hadoop.hive.metastore.api.DecimalColumnStatsData; import org.apache.hadoop.hive.metastore.api.DoubleColumnStatsData; import org.apache.hadoop.hive.metastore.api.LongColumnStatsData; +import org.apache.hadoop.hive.metastore.api.StringColumnStatsData; import org.apache.hadoop.hive.metastore.api.Timestamp; import org.apache.hadoop.hive.metastore.api.TimestampColumnStatsData; import org.apache.hadoop.hive.ql.plan.ColStatistics; @@ -285,6 +289,204 @@ void testStatisticsAddToColumnStatsPropagatesUnknownFromExisting() { assertEquals(-1, merged.getNumNulls(), "Unknown numNulls (-1) should be propagated when existing is unknown"); } + @ParameterizedTest(name = "{0}") + @MethodSource("addToColumnStatsCountDistinctCases") + void testStatisticsAddToColumnStatsCountDistinctMerge( + String scenarioName, long existingNdv, long incomingNdv, long expectedMergedNdv) { + Statistics stats = new Statistics(1000, 8000, 0, 0); + ColStatistics existing = createColStats("col1", existingNdv, 0); + stats.setColumnStats(Collections.singletonList(existing)); + + ColStatistics incoming = createColStats("col1", incomingNdv, 0); + stats.addToColumnStats(Collections.singletonList(incoming)); + + ColStatistics merged = stats.getColumnStatisticsFromColName("col1"); + assertEquals(expectedMergedNdv, merged.getCountDistint(), + "countDistinct after merge"); + } + + private static Stream addToColumnStatsCountDistinctCases() { + return Stream.of( + Arguments.of("incomingUnknownPropagates", 5L, -1L, -1L), + 
Arguments.of("existingUnknownPropagates", -1L, 5L, -1L), + Arguments.of("bothUnknownStaysUnknown", -1L, -1L, -1L), + Arguments.of("maxPicksIncomingWhenHigher", 3L, 7L, 7L), + Arguments.of("maxPicksExistingWhenHigher", 7L, 3L, 7L) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("containsUnknownNDVCases") + void testContainsUnknownNDV(String scenarioName, List input, boolean expected) { + assertEquals(expected, StatsUtils.containsUnknownNDV(input), + "containsUnknownNDV(" + input + ")"); + } + + private static Stream containsUnknownNDVCases() { + return Stream.of( + Arguments.of("allPositive", Arrays.asList(1L, 2L, 3L), false), + Arguments.of("containsZero_NotUnknown", Arrays.asList(1L, 0L, 3L), false), + Arguments.of("singleUnknown", Arrays.asList(1L, -1L, 3L), true), + Arguments.of("allUnknown", Arrays.asList(-1L, -1L, -1L), true), + Arguments.of("firstIsUnknown_ShortCircuit", Arrays.asList(-1L, 2L, 3L), true), + Arguments.of("emptyList", Collections.emptyList(), false) + ); + } + + @Test + void testAddWithExpDecayReturnsUnknownWhenAnyInputIsUnknown() { + Long result = StatsUtils.addWithExpDecay(Arrays.asList(10L, -1L, 5L)); + assertEquals(-1L, result, "addWithExpDecay should propagate unknown NDV (-1) when present"); + } + + @Test + void testAddWithExpDecayComputesWhenAllInputsKnown() { + Long result = StatsUtils.addWithExpDecay(Arrays.asList(100L, 25L)); + // Exponential decay: 100 * 25^(1/2) = 100 * 5 = 500. 
+ assertEquals(500L, result, "addWithExpDecay should return the exponential-decay denominator for known inputs"); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("computeNDVGroupingColumnsCases") + void testComputeNDVGroupingColumns(String scenarioName, List colStats, + Statistics.State parentColStatsState, boolean expDecay, long expected) { + Statistics parentStats = new Statistics(1000, 8000, 0, 0); + parentStats.setColumnStatsState(parentColStatsState); + + long result = StatsUtils.computeNDVGroupingColumns(colStats, parentStats, expDecay); + + assertEquals(expected, result, scenarioName); + } + + private static Stream computeNDVGroupingColumnsCases() { + return Stream.of( + Arguments.of("allKnownReturnsProduct", + Arrays.asList(makeColStat("c1", 10), makeColStat("c2", 20)), + Statistics.State.COMPLETE, false, 200L), + Arguments.of("unknownColumnReturnsMinusOne", + Arrays.asList(makeColStat("c1", 10), makeColStat("c2", -1)), + Statistics.State.COMPLETE, false, -1L), + Arguments.of("emptyColumnsReturnsOne", + Collections.emptyList(), + Statistics.State.COMPLETE, false, 1L), + Arguments.of("nullColStatWithCompleteParentSkipped", + Arrays.asList(null, makeColStat("c2", 10)), + Statistics.State.COMPLETE, false, 10L), + Arguments.of("nullColStatWithPartialParentReturnsMinusOne", + Arrays.asList((ColStatistics) null), + Statistics.State.PARTIAL, false, -1L), + Arguments.of("expDecayWithKnownInputs", + Arrays.asList(makeColStat("c1", 100), makeColStat("c2", 25)), + Statistics.State.COMPLETE, true, 500L), + Arguments.of("expDecayWithUnknownPropagates", + Arrays.asList(makeColStat("c1", 100), makeColStat("c2", -1)), + Statistics.State.COMPLETE, true, -1L) + ); + } + + private static ColStatistics makeColStat(String name, long ndv) { + ColStatistics cs = new ColStatistics(name, "string"); + cs.setCountDistint(ndv); + cs.setNumNulls(0); + return cs; + } + + @ParameterizedTest(name = "{0}") + @MethodSource("getColStatisticsUnsetNumDVsCases") + void 
testGetColStatisticsReturnsUnknownNDVWhenNumDVsNotSet( + String typeName, ColumnStatisticsData data) { + ColumnStatisticsObj cso = new ColumnStatisticsObj(); + cso.setColName("test_col"); + cso.setColType(typeName); + cso.setStatsData(data); + + ColStatistics cs = StatsUtils.getColStatistics(cso, "test_col"); + + assertNotNull(cs, "ColStatistics should not be null for " + typeName); + assertEquals(-1, cs.getCountDistint(), + "When numDVs is unset for " + typeName + ", NDV should be -1"); + } + + private static Stream getColStatisticsUnsetNumDVsCases() { + LongColumnStatsData longStats = new LongColumnStatsData(); + longStats.setNumNulls(10); + // numDVs NOT set + + DoubleColumnStatsData doubleStats = new DoubleColumnStatsData(); + doubleStats.setNumNulls(10); + + StringColumnStatsData stringStats = new StringColumnStatsData(); + stringStats.setNumNulls(10); + stringStats.setAvgColLen(5.0); + stringStats.setMaxColLen(20); + + BinaryColumnStatsData binaryStats = new BinaryColumnStatsData(); + binaryStats.setNumNulls(10); + binaryStats.setAvgColLen(5.0); + binaryStats.setMaxColLen(20); + + TimestampColumnStatsData timestampStats = new TimestampColumnStatsData(); + timestampStats.setNumNulls(10); + + DecimalColumnStatsData decimalStats = new DecimalColumnStatsData(); + decimalStats.setNumNulls(10); + + DateColumnStatsData dateStats = new DateColumnStatsData(); + dateStats.setNumNulls(10); + + return Stream.of( + Arguments.of(serdeConstants.BIGINT_TYPE_NAME, wrapLong(longStats)), + Arguments.of(serdeConstants.DOUBLE_TYPE_NAME, wrapDouble(doubleStats)), + Arguments.of(serdeConstants.STRING_TYPE_NAME, wrapString(stringStats)), + Arguments.of(serdeConstants.BINARY_TYPE_NAME, wrapBinary(binaryStats)), + Arguments.of(serdeConstants.TIMESTAMP_TYPE_NAME, wrapTimestamp(timestampStats)), + Arguments.of(serdeConstants.DECIMAL_TYPE_NAME, wrapDecimal(decimalStats)), + Arguments.of(serdeConstants.DATE_TYPE_NAME, wrapDate(dateStats)) + ); + } + + private static ColumnStatisticsData 
wrapLong(LongColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setLongStats(s); + return d; + } + + private static ColumnStatisticsData wrapDouble(DoubleColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setDoubleStats(s); + return d; + } + + private static ColumnStatisticsData wrapString(StringColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setStringStats(s); + return d; + } + + private static ColumnStatisticsData wrapBinary(BinaryColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setBinaryStats(s); + return d; + } + + private static ColumnStatisticsData wrapTimestamp(TimestampColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setTimestampStats(s); + return d; + } + + private static ColumnStatisticsData wrapDecimal(DecimalColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setDecimalStats(s); + return d; + } + + private static ColumnStatisticsData wrapDate(DateColumnStatsData s) { + ColumnStatisticsData d = new ColumnStatisticsData(); + d.setDateStats(s); + return d; + } + @Test void testGetColStatisticsBooleanWithUnknownNumTrues() { ColumnStatisticsObj cso = new ColumnStatisticsObj(); @@ -465,6 +667,53 @@ void testUpdateStatsPreservesUnknownNumNulls() { assertEquals(-1, updated.getNumNulls(), "Unknown numNulls (-1) should be preserved after scaling"); } + @Test + void testUpdateStatsMarksFilteredColumnEvenWhenNDVUnknown() { + // HIVE-29625: setFilterColumn() is now called unconditionally for affected columns, + // even when NDV is unknown (-1). The NDV math is skipped but the filter mark applies. 
+ Statistics stats = new Statistics(1000, 8000, 0, 0); + ColStatistics cs = createColStats("col1", -1, 0); // unknown NDV + stats.setColumnStats(Collections.singletonList(cs)); + + StatsUtils.updateStats(stats, 500, true, null, Collections.singleton("col1")); + + ColStatistics updated = stats.getColumnStats().get(0); + assertEquals(true, updated.isFilteredColumn(), + "Filter-column flag should be set even when NDV is unknown"); + assertEquals(-1, updated.getCountDistint(), + "Unknown NDV (-1) should be preserved when affected column has no NDV"); + } + + @Test + void testUpdateStatsRecomputesNDVWhenAffectedAndKnown() { + // Regression check: when NDV is known and ratio <= 1.0, the NDV math still runs + // inside the new oldDV >= 0 guard. + Statistics stats = new Statistics(1000, 8000, 0, 0); + ColStatistics cs = createColStats("col1", 100, 0); // known NDV + stats.setColumnStats(Collections.singletonList(cs)); + + StatsUtils.updateStats(stats, 500, true, null, Collections.singleton("col1")); + + ColStatistics updated = stats.getColumnStats().get(0); + assertEquals(true, updated.isFilteredColumn(), + "Filter-column flag should be set for affected column with known NDV"); + // ratio = 500/1000 = 0.5 -> newDV = ceil(0.5 * 100) = 50 + assertEquals(50, updated.getCountDistint(), + "Known NDV should be scaled by the row-count ratio"); + } + + @Test + void testScaleColStatisticsPreservesUnknownCountDistint() { + // HIVE-29625: when factor < 1.0 and NDV is unknown (-1), the sentinel is preserved. 
+ ColStatistics cs = createColStats("col1", -1, 0); // unknown NDV + List colStats = Collections.singletonList(cs); + + StatsUtils.scaleColStatistics(colStats, 0.5); + + assertEquals(-1, colStats.get(0).getCountDistint(), + "Unknown NDV (-1) should be preserved when factor < 1.0"); + } + @Test void testScaleColStatisticsPreservesUnknownNumNulls() { ColStatistics cs = createColStats("col1", 100, -1); // unknown numNulls diff --git a/ql/src/test/org/apache/hadoop/hive/ql/stats/estimator/TestPessimisticStatCombiner.java b/ql/src/test/org/apache/hadoop/hive/ql/stats/estimator/TestPessimisticStatCombiner.java index 98bc589e40d3..752907f755d8 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/stats/estimator/TestPessimisticStatCombiner.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/stats/estimator/TestPessimisticStatCombiner.java @@ -20,8 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.stream.Stream; + import org.apache.hadoop.hive.ql.plan.ColStatistics; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; class TestPessimisticStatCombiner { @@ -136,6 +141,31 @@ void testCombineBothUnknownNumNulls() { assertEquals(-1, combined.getNumNulls(), "Both unknown should result in unknown (-1)"); } + @ParameterizedTest(name = "{0}") + @MethodSource("combineCountDistinctCases") + void testCombineCountDistinctMerge(String scenarioName, long stat1Ndv, long stat2Ndv, long expectedNdv) { + ColStatistics stat1 = createStat("col1", "int", stat1Ndv, 5, 4.0); + ColStatistics stat2 = createStat("col2", "int", stat2Ndv, 10, 4.0); + + PessimisticStatCombiner combiner = new PessimisticStatCombiner(); + combiner.add(stat1); + combiner.add(stat2); + + ColStatistics combined = combiner.getResult().get(); + assertEquals(expectedNdv, combined.getCountDistint(), + "countDistinct after PROPAGATE combine"); + } + + 
private static Stream combineCountDistinctCases() { + return Stream.of( + Arguments.of("firstUnknownPropagates", -1L, 50L, -1L), + Arguments.of("secondUnknownPropagates", 50L, -1L, -1L), + Arguments.of("bothUnknownStaysUnknown", -1L, -1L, -1L), + Arguments.of("picksHigherWhenSecondHigher", 30L, 50L, 50L), + Arguments.of("keepsHigherWhenFirstHigher", 50L, 30L, 50L) + ); + } + @Test void testCombineBothUnknownNumTruesAndNumFalses() { ColStatistics stat1 = createStat("col1", "boolean", 2, 5, 1.0); From 349be2f2986eeb78a48daf02dd3208d88c4c7fd7 Mon Sep 17 00:00:00 2001 From: Konstantin Bereznyakov Date: Wed, 27 May 2026 13:35:15 -0700 Subject: [PATCH 4/5] HIVE-29625: SQ feedback + better test code --- .../annotation/StatsRulesProcFactory.java | 3 +- .../optimizer/TestReduceSinkMapJoinProc.java | 56 ++-- .../TestSetHashGroupByMinReduction.java | 125 ++++----- .../TestSortedDynPartitionOptimizer.java | 48 ++-- .../stats/TestHiveRelMdDistinctRowCount.java | 4 +- .../annotation/TestStatsRulesProcFactory.java | 257 ++++++++---------- 6 files changed, 218 insertions(+), 275 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java index 19985a2fe82c..fe65494e0dfe 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java @@ -1536,8 +1536,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, ndvProduct = parentNumRows / 2; if (LOG.isDebugEnabled()) { - LOG.debug("STATS-" + gop.toString() + ": ndvProduct unknown; falling back to " - + ndvProduct); + LOG.debug("STATS-{}: ndvProduct unknown; falling back to {}", gop, ndvProduct); } } final long maxColumnNDV = colStats.stream() diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java index 415acfd10eff..75feeff971ff 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java @@ -29,14 +29,18 @@ import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.CompilationOpContext; import org.apache.hadoop.hive.ql.Context; import org.apache.hadoop.hive.ql.exec.MapJoinOperator; +import org.apache.hadoop.hive.ql.exec.Operator; import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; +import org.apache.hadoop.hive.ql.exec.RowSchema; import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.parse.GenTezProcContext; import org.apache.hadoop.hive.ql.parse.ParseContext; @@ -44,10 +48,10 @@ import org.apache.hadoop.hive.ql.plan.ColStatistics; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.MapJoinDesc; +import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc; import org.apache.hadoop.hive.ql.plan.Statistics; import org.apache.hadoop.hive.ql.stats.StatsUtils; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -60,12 +64,15 @@ class TestReduceSinkMapJoinProc { * Master used `cs.getCountDistint() <= 0` to fall back to MAX_VALUE; HIVE-29625 uses * `cs.getCountDistint() < 0`, so verified-zero NDV no longer cascades to "no info" * but falls through to `maxKeyCount *= 0` (then clamped to 1 by later logic). 
+ * + * A null `ndv` row represents StatsUtils.getColStatisticsFromExpression returning null + * (no derivable stat) - shares the same MAX_VALUE fallback as NDV < 0. */ @ParameterizedTest(name = "{0}") @MethodSource("keyCountFromNdvCases") void testProcessReduceSinkToHashJoinKeyCountFromNdv( - String scenarioName, long ndv, long parentRows, long expectedKeyCount) throws Exception { - invokeAndAssertKeyCount(buildColStat(ndv), parentRows, expectedKeyCount, scenarioName); + String scenarioName, Long ndv, long parentRows, long expectedKeyCount) throws Exception { + invokeAndAssertKeyCount(ndv == null ? null : buildColStat(ndv), parentRows, expectedKeyCount); } private static Stream keyCountFromNdvCases() { @@ -79,23 +86,18 @@ private static Stream keyCountFromNdvCases() { // joinConf.getParentKeyCounts().put(pos, keyCount) [only if keyCount != MAX_VALUE] return Stream.of( // NDV > 0 and below parentRows -> keyCount = NDV - Arguments.of("knownPositiveBelowParent", 10L, 1000L, 10L), + Arguments.of("knownPositiveBelowParent", 10L, 1000L, 10L), // NDV > 0 but above parentRows -> capped at parentRows - Arguments.of("knownPositiveAboveParent", 5000L, 1000L, 1000L), + Arguments.of("knownPositiveAboveParent", 5000L, 1000L, 1000L), // NDV = 0 (verified zero under HIVE-29625) -> maxKeyCount = 0 -> keyCount = 0 -> bumped to 1 - Arguments.of("verifiedZeroBumpsToOne", 0L, 1000L, 1L), + Arguments.of("verifiedZeroBumpsToOne", 0L, 1000L, 1L), // NDV = -1 (unknown under HIVE-29625) -> maxKeyCount = MAX_VALUE -> keyCount = parentRows - Arguments.of("unknownFallsBackToParent", -1L, 1000L, 1000L) + Arguments.of("unknownFallsBackToParent", -1L, 1000L, 1000L), + // cs == null (no derivable stat) -> shares the MAX_VALUE fallback path + Arguments.of("nullColStatsFallsBackToParent", null, 1000L, 1000L) ); } - @Test - void testProcessReduceSinkToHashJoinNullColStatsFallsBackToParent() throws Exception { - // cs == null case (StatsUtils could not derive a stat for the expr) -> same fallback as 
< 0 - invokeAndAssertKeyCount(null, 1000L, 1000L, - "Null ColStatistics falls back to parent row count (same as NDV < 0)"); - } - /** * Shared invocation harness: build a real GenTezProcContext + mocked operator chain, * stub StatsUtils.getColStatisticsFromExpression to return the given colStat for @@ -103,7 +105,7 @@ void testProcessReduceSinkToHashJoinNullColStatsFallsBackToParent() throws Excep * landed in joinConf.getParentKeyCounts() at position 0. */ private static void invokeAndAssertKeyCount( - ColStatistics csForKey, long parentRows, long expectedKeyCount, String desc) throws Exception { + ColStatistics csForKey, long parentRows, long expectedKeyCount) throws Exception { // ---- Operator chain mocks ---- ReduceSinkOperator parentRS = mock(ReduceSinkOperator.class); @@ -116,9 +118,16 @@ private static void invokeAndAssertKeyCount( when(parentRS.getConf()).thenReturn(rsConf); when(parentRS.getStatistics()).thenReturn(rsStats); + when(parentRS.getCompilationOpContext()).thenReturn(new CompilationOpContext()); Map columnExprMap = new HashMap<>(); columnExprMap.put(Utilities.ReduceField.KEY.toString() + ".k0", keyExpr); when(parentRS.getColumnExprMap()).thenReturn(columnExprMap); + Operator upstreamParent = mock(Operator.class); + when(upstreamParent.getSchema()).thenReturn(new RowSchema(Collections.emptyList())); + when(parentRS.getParentOperators()).thenReturn(Arrays.asList(upstreamParent)); + List> childOps = new ArrayList<>(); + childOps.add(mapJoinOp); + when(parentRS.getChildOperators()).thenReturn(childOps); when(mapJoinOp.getConf()).thenReturn(joinConf); @@ -130,6 +139,9 @@ private static void invokeAndAssertKeyCount( when(joinConf.getParentKeyCounts()).thenReturn(parentKeyCounts); when(joinConf.getParentToInput()).thenReturn(new LinkedHashMap<>()); when(joinConf.getParentDataSizes()).thenReturn(new LinkedHashMap<>()); + Map> keyExprMap = new HashMap<>(); + keyExprMap.put((byte) 0, Collections.emptyList()); + 
when(joinConf.getKeys()).thenReturn(keyExprMap); when(rsStats.getNumRows()).thenReturn(parentRows); when(rsStats.getDataSize()).thenReturn(8000L); @@ -155,21 +167,11 @@ private static void invokeAndAssertKeyCount( any(HiveConf.class), any(Statistics.class), any(ExprNodeDesc.class))) .thenReturn(csForKey); - try { - ReduceSinkMapJoinProc.processReduceSinkToHashJoin(parentRS, mapJoinOp, context); - } catch (NullPointerException expected) { - // The method continues for ~100 more lines past the keyCount put() that this - // test is verifying - dummy-operator construction, key-table-desc setup, edge - // wiring, etc. - all requiring deep operator-chain mocking irrelevant to the - // HIVE-29625 NDV branch. By the time any NPE fires from that downstream code, - // joinConf.getParentKeyCounts() already received the put() we asserted on - // (line ~229 of ReduceSinkMapJoinProc, well before the first dependency this - // harness doesn't mock). - } + ReduceSinkMapJoinProc.processReduceSinkToHashJoin(parentRS, mapJoinOp, context); } Long actual = parentKeyCounts.get(0); - assertEquals(expectedKeyCount, actual.longValue(), desc); + assertEquals(expectedKeyCount, actual.longValue()); } private static ColStatistics buildColStat(long ndv) { diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java index 639af98f701d..3d1bdf2f71b2 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java @@ -31,6 +31,8 @@ import java.util.Arrays; import java.util.Collections; +import java.util.function.Consumer; +import java.util.stream.Stream; import org.apache.hadoop.hive.ql.exec.GroupByOperator; import org.apache.hadoop.hive.ql.exec.Operator; @@ -41,95 +43,81 @@ import org.apache.hadoop.hive.ql.plan.Statistics; import 
org.apache.hadoop.hive.ql.plan.Statistics.State; import org.apache.hadoop.hive.ql.stats.StatsUtils; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.MockedStatic; class TestSetHashGroupByMinReduction { - @Test - void testProcessReturnsNullWhenNdvProductIsUnknown() throws SemanticException { - GroupByOperator op = setupCompleteHashGroupBy(0.5f, 0.1f); - GroupByDesc desc = op.getConf(); - - try (MockedStatic stub = mockStatic(StatsUtils.class)) { - stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) - .thenReturn(-1L); + // Default-reduction tuple used across all tests. Picked so the known-positive case + // (ndvProduct=100, numRows=1000) produces a factor (0.9) strictly below the default + // (0.99), triggering setMinReductionHashAggr. + private static final float DEFAULT_MIN_REDUCTION = 0.99f; + private static final float DEFAULT_MIN_REDUCTION_LOWER_BOUND = 0.1f; - Object result = new SetHashGroupByMinReduction().process(op, null, null); - - assertNull(result, "Unknown ndvProduct (-1) makes process() return null"); - verify(desc, never()).setMinReductionHashAggr(anyFloat()); - } - } - - @Test - void testProcessProceedsWhenNdvProductIsVerifiedZero() throws SemanticException { - // HIVE-29625 disambiguation: ndvProduct == 0 (verified) now distinct from ndvProduct < 0 (unknown). - // Verified-zero means there are zero distinct group keys -> maximum reduction (factor = 1.0). 
- GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); + /** + * NDV product drives the central HIVE-29625 disambiguation: + * ndvProduct < 0 -> unknown -> early-return null (no setMinReductionHashAggr call) + * ndvProduct == 0 -> verified zero -> factor = 1.0, NOT less than default 0.99 (no call) + * ndvProduct > 0 -> compute factor = 1 - ndvProduct/numRows, set if < default + */ + @ParameterizedTest(name = "{0}") + @MethodSource("ndvProductCases") + void testProcessByNdvProduct(String name, long ndvProduct, boolean expectSetCall) + throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(); GroupByDesc desc = op.getConf(); try (MockedStatic stub = mockStatic(StatsUtils.class)) { stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) - .thenReturn(0L); + .thenReturn(ndvProduct); Object result = new SetHashGroupByMinReduction().process(op, null, null); assertNull(result, "process() always returns null sentinel"); - // 1f - (0 / numRows) = 1.0; capped by lowerBound 0.5 -> stays 1.0; less than default 0.99? No, 1.0 > 0.99. - // So setMinReductionHashAggr is NOT called (newFactor not strictly less than default). - verify(desc, never()).setMinReductionHashAggr(anyFloat()); - } - } - - @Test - void testProcessProceedsWhenNdvProductIsKnownPositive() throws SemanticException { - // numRows=1000, ndvProduct=100 -> factor = 1 - 100/1000 = 0.9 - // default = 0.99 -> 0.9 < 0.99 -> setMinReductionHashAggr(0.9) - GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.1f); - GroupByDesc desc = op.getConf(); - - try (MockedStatic stub = mockStatic(StatsUtils.class)) { - stub.when(() -> StatsUtils.computeNDVGroupingColumns(any(), any(), eq(true))) - .thenReturn(100L); - - new SetHashGroupByMinReduction().process(op, null, null); - - verify(desc, atLeastOnce()).setMinReductionHashAggr(anyFloat()); + verify(desc, expectSetCall ? 
atLeastOnce() : never()).setMinReductionHashAggr(anyFloat()); } } - @Test - void testProcessReturnsNullWhenModeNotHash() throws SemanticException { - GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); - when(op.getConf().getMode()).thenReturn(Mode.MERGEPARTIAL); - - Object result = new SetHashGroupByMinReduction().process(op, null, null); - - assertNull(result, "Non-HASH mode -> early return null"); - verify(op.getConf(), never()).setMinReductionHashAggr(anyFloat()); + private static Stream ndvProductCases() { + return Stream.of( + Arguments.of("unknownNDVEarlyReturns", -1L, false), + Arguments.of("verifiedZeroFactorTooHigh", 0L, false), + Arguments.of("knownPositiveBelowDefault", 100L, true) + ); } - @Test - void testProcessReturnsNullWhenBasicStatsIncomplete() throws SemanticException { - GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); - when(op.getStatistics().getBasicStatsState()).thenReturn(State.PARTIAL); + /** + * Each gate is one of the early-return conditions in process(): non-HASH mode, + * incomplete basic stats, or incomplete column stats. All three early-return without + * touching setMinReductionHashAggr. 
+ */ + @ParameterizedTest(name = "{0}") + @MethodSource("earlyReturnGateCases") + void testProcessEarlyReturnsOnUnsupportedState(String name, Consumer flipGate) + throws SemanticException { + GroupByOperator op = setupCompleteHashGroupBy(); + flipGate.accept(op); Object result = new SetHashGroupByMinReduction().process(op, null, null); - assertNull(result, "Incomplete basic stats -> early return null"); + assertNull(result); verify(op.getConf(), never()).setMinReductionHashAggr(anyFloat()); } - @Test - void testProcessReturnsNullWhenColumnStatsIncomplete() throws SemanticException { - GroupByOperator op = setupCompleteHashGroupBy(0.99f, 0.5f); - when(op.getStatistics().getColumnStatsState()).thenReturn(State.PARTIAL); - - Object result = new SetHashGroupByMinReduction().process(op, null, null); - - assertNull(result, "Incomplete column stats -> early return null"); - verify(op.getConf(), never()).setMinReductionHashAggr(anyFloat()); + private static Stream earlyReturnGateCases() { + return Stream.of( + Arguments.of("modeNotHash", + (Consumer) op -> + when(op.getConf().getMode()).thenReturn(Mode.MERGEPARTIAL)), + Arguments.of("basicStatsIncomplete", + (Consumer) op -> + when(op.getStatistics().getBasicStatsState()).thenReturn(State.PARTIAL)), + Arguments.of("columnStatsIncomplete", + (Consumer) op -> + when(op.getStatistics().getColumnStatsState()).thenReturn(State.PARTIAL)) + ); } /** @@ -137,8 +125,7 @@ void testProcessReturnsNullWhenColumnStatsIncomplete() throws SemanticException * so the colStats loop is a no-op (the inputs to computeNDVGroupingColumns are * controlled directly via mockStatic in each test). 
*/ - private static GroupByOperator setupCompleteHashGroupBy( - float defaultMinReduction, float defaultMinReductionLowerBound) { + private static GroupByOperator setupCompleteHashGroupBy() { GroupByOperator op = mock(GroupByOperator.class); GroupByDesc desc = mock(GroupByDesc.class); Statistics stats = mock(Statistics.class); @@ -153,8 +140,8 @@ private static GroupByOperator setupCompleteHashGroupBy( when(desc.getMode()).thenReturn(Mode.HASH); when(desc.getKeys()).thenReturn(Collections.emptyList()); - when(desc.getMinReductionHashAggr()).thenReturn(defaultMinReduction); - when(desc.getMinReductionHashAggrLowerBound()).thenReturn(defaultMinReductionLowerBound); + when(desc.getMinReductionHashAggr()).thenReturn(DEFAULT_MIN_REDUCTION); + when(desc.getMinReductionHashAggrLowerBound()).thenReturn(DEFAULT_MIN_REDUCTION_LOWER_BOUND); when(stats.getBasicStatsState()).thenReturn(State.COMPLETE); when(stats.getColumnStatsState()).thenReturn(State.COMPLETE); diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java index 64d12657a452..b92f502e12d1 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java @@ -19,14 +19,12 @@ package org.apache.hadoop.hive.ql.optimizer; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Function; @@ -61,27 +59,18 @@ void testComputePartCardinalityColumnBranch( RowSchema schema = mock(RowSchema.class); 
when(fsParent.getSchema()).thenReturn(schema); + ColStatistics[] colStats = buildColStats(ndvs, firstStatNull, "p"); List sig = new ArrayList<>(); + List partitionPos = new ArrayList<>(); for (int i = 0; i < ndvs.length; i++) { String colName = "p" + i; ColumnInfo ci = mock(ColumnInfo.class); when(ci.getInternalName()).thenReturn(colName); sig.add(ci); - - if (i == 0 && firstStatNull) { - when(tStats.getColumnStatisticsFromColName(colName)).thenReturn(null); - } else { - ColStatistics cs = new ColStatistics(colName, "int"); - cs.setCountDistint(ndvs[i]); - when(tStats.getColumnStatisticsFromColName(colName)).thenReturn(cs); - } - } - when(schema.getSignature()).thenReturn(sig); - - List partitionPos = new ArrayList<>(); - for (int i = 0; i < ndvs.length; i++) { + when(tStats.getColumnStatisticsFromColName(colName)).thenReturn(colStats[i]); partitionPos.add(i); } + when(schema.getSignature()).thenReturn(sig); long result = proc.computePartCardinality( partitionPos, Collections.emptyList(), tStats, fsParent, new ArrayList<>()); @@ -127,18 +116,12 @@ void testComputePartCardinalityCustomExprBranch( exprs.add(cols -> resolved); } + ColStatistics[] colStats = buildColStats(ndvs, firstStatNull, "e"); try (MockedStatic stub = mockStatic(StatsUtils.class)) { for (int i = 0; i < ndvs.length; i++) { final int idx = i; - final ColStatistics cs; - if (idx == 0 && firstStatNull) { - cs = null; - } else { - cs = new ColStatistics("e" + idx, "int"); - cs.setCountDistint(ndvs[idx]); - } stub.when(() -> StatsUtils.getColStatisticsFromExpression(eq(conf), eq(tStats), eq(resolvedExprs.get(idx)))) - .thenReturn(cs); + .thenReturn(colStats[idx]); } long result = proc.computePartCardinality( @@ -175,4 +158,23 @@ private static SortedDynPartitionOptimizer.SortedDynamicPartitionProc newProc(Pa SortedDynPartitionOptimizer outer = new SortedDynPartitionOptimizer(); return outer.new SortedDynamicPartitionProc(parseCtx); } + + /** + * Builds one ColStatistics per ndvs entry; the first entry 
is null when firstStatNull + * is true (used to simulate "missing stats" for either the partition-column or the + * custom-expression branch). + */ + private static ColStatistics[] buildColStats(long[] ndvs, boolean firstStatNull, String prefix) { + ColStatistics[] result = new ColStatistics[ndvs.length]; + for (int i = 0; i < ndvs.length; i++) { + if (i == 0 && firstStatNull) { + result[i] = null; + } else { + ColStatistics cs = new ColStatistics(prefix + i, "int"); + cs.setCountDistint(ndvs[i]); + result[i] = cs; + } + } + return result; + } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java index 60cc0cc28d5e..c49b6db33b28 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestHiveRelMdDistinctRowCount.java @@ -24,7 +24,6 @@ import static org.mockito.Mockito.when; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.stream.Stream; @@ -51,8 +50,7 @@ void testGetDistinctRowCountHiveTableScan( HiveRelMdDistinctRowCount provider = new HiveRelMdDistinctRowCount(); Double result = provider.getDistinctRowCount(htRel, mq, ImmutableBitSet.of(0), null); - assertEquals(expected, result, - "getDistinctRowCount for ndvs=" + Arrays.toString(ndvs) + " rowCount=" + rowCount); + assertEquals(expected, result); } private static Stream getDistinctRowCountCases() { diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java index 1bcce8d02a12..7eefc84c4ac8 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java +++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java @@ -80,7 +80,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; -public class TestStatsRulesProcFactory { +class TestStatsRulesProcFactory { private final static String COL_NAME = "col1"; private final static ExprNodeDesc COL_EXPR = new ExprNodeColumnDesc( @@ -91,7 +91,7 @@ public class TestStatsRulesProcFactory { private final static long[] VALUES = { 1L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L, 4L, 5L, 6L, 7L }; @Test - public void testComparisonRowCountZeroNonNullValues() throws SemanticException { + void testComparisonRowCountZeroNonNullValues() throws SemanticException { long numNulls = 2; long[] values = {}; Statistics stats = createStatistics(values, numNulls); @@ -105,7 +105,7 @@ public void testComparisonRowCountZeroNonNullValues() throws SemanticException { } @Test - public void testComparisonRowCountInvalidKll() throws SemanticException { + void testComparisonRowCountInvalidKll() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); stats.getColumnStats().get(0).setHistogram(null); @@ -127,13 +127,17 @@ public void testComparisonRowCountInvalidKll() throws SemanticException { assertEquals((VALUES.length + numNulls) / 3, numRows); } - @Test - public void testEvaluateInExprWithUnknownNDVAppliesHalfFactor() throws SemanticException { - // HIVE-29625: when the column's NDV is unknown (-1), the IN filter takes - // factor *= 0.5 and continues (rather than the old behavior of treating - // dvs==0 as unknown). currNumRows=13, factor=0.5, inFactor=1.0 (default). + /** + * HIVE-29625: IN-filter row-count estimate by NDV of the column. 
+ * unknown NDV (-1) -> factor *= 0.5 per IN column, currNumRows = round(rows * 0.5) + * verified-zero NDV (0) -> factor = 0, no rows match + */ + @ParameterizedTest(name = "{0}") + @MethodSource("evaluateInExprCases") + void testEvaluateInExprByNDV(String name, long ndvOverride, long expectedRows) + throws SemanticException { Statistics stats = createStatistics(VALUES, 0); - stats.getColumnStats().get(0).setCountDistint(-1); // force unknown NDV + stats.getColumnStats().get(0).setCountDistint(ndvOverride); AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null)); when(ctx.getConf()).thenReturn(new HiveConf()); @@ -145,34 +149,28 @@ public void testEvaluateInExprWithUnknownNDVAppliesHalfFactor() throws SemanticE long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, inExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length); - assertEquals(Math.round(VALUES.length * 0.5), numRows); + assertEquals(expectedRows, numRows); } - @Test - public void testEvaluateEqualWithUnknownNDVUsesHalfRows() throws SemanticException { - // HIVE-29625: col = const where col.NDV=-1 (unknown) falls back to numRows/2. - // VALUES.length=13, expected = 13/2 = 6 (long division).
- Statistics stats = createStatistics(VALUES, 0); - stats.getColumnStats().get(0).setCountDistint(-1); - - AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null)); - when(ctx.getConf()).thenReturn(new HiveConf()); - - ExprNodeDesc eqExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, - new GenericUDFOPEqual(), - Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); - - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length); - - assertEquals(6, numRows); + private static Stream evaluateInExprCases() { + return Stream.of( + Arguments.of("unknownNDVAppliesHalfFactor", -1L, Math.round(VALUES.length * 0.5)), + Arguments.of("verifiedZeroReturnsZero", 0L, 0L) + ); } - @Test - public void testEvaluateEqualWithVerifiedZeroNDVReturnsZero() throws SemanticException { - // HIVE-29625: col = const where col.NDV=0 (verified zero) returns 0 rows. + /** + * HIVE-29625: col = const row-count estimate by NDV of the column. 
+ * unknown NDV (-1) -> numRows/2 (13/2 = 6) + * verified-zero NDV (0) -> 0 rows match + * known NDV (n) -> uniform distribution numRows/n (round(13/7) = 2) + */ + @ParameterizedTest(name = "{0}") + @MethodSource("evaluateEqualCases") + void testEvaluateEqualByNDV(String name, long ndvOverride, long expectedRows) + throws SemanticException { Statistics stats = createStatistics(VALUES, 0); - stats.getColumnStats().get(0).setCountDistint(0); + stats.getColumnStats().get(0).setCountDistint(ndvOverride); AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null)); when(ctx.getConf()).thenReturn(new HiveConf()); @@ -184,12 +182,20 @@ public void testEvaluateEqualWithVerifiedZeroNDVReturnsZero() throws SemanticExc long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length); - assertEquals(0, numRows); + assertEquals(expectedRows, numRows); + } + + private static Stream evaluateEqualCases() { + return Stream.of( + Arguments.of("unknownNDVUsesHalfRows", -1L, 6L), + Arguments.of("verifiedZeroReturnsZero", 0L, 0L), + Arguments.of("knownNDVUsesUniformDistribution", 7L, 2L) + ); } @ParameterizedTest(name = "{0}") @MethodSource("groupByFinalCases") - public void testGroupByStatsRuleFinalCardinality(String name, long keyNdv, long expectedRows) throws SemanticException { + void testGroupByStatsRuleFinalCardinality(String name, long keyNdv, long expectedRows) throws SemanticException { assertGroupByFinalCardinality(keyNdv, expectedRows); } @@ -203,7 +209,8 @@ private static Stream groupByFinalCases() { @ParameterizedTest(name = "{0}") @MethodSource("groupByHashCases") - public void testCheckMapSideAggregationHashCardinality(String name, long keyNdv, long expectedRows) throws SemanticException { + void testCheckMapSideAggregationHashCardinality(String name, long keyNdv, long expectedRows) + throws SemanticException { assertGroupByHashCardinality(keyNdv, expectedRows); } @@ -324,46 +331,7
@@ private void assertGroupByFinalCardinality(long keyNdv, long expectedRows) throw } @Test - public void testEvaluateEqualWithKnownNDVUsesUniformDistribution() throws SemanticException { - // Regression check: col = const where col.NDV=7 returns round(13/7)=2 rows. - // VALUES has 7 distinct values, so createStatistics sets NDV=7. - Statistics stats = createStatistics(VALUES, 0); - - AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null)); - when(ctx.getConf()).thenReturn(new HiveConf()); - - ExprNodeDesc eqExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, - new GenericUDFOPEqual(), - Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); - - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, eqExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length); - - assertEquals(2, numRows); - } - - @Test - public void testEvaluateInExprWithVerifiedZeroNDVReturnsZero() throws SemanticException { - // HIVE-29625: when the column's NDV is verified zero (0), the IN filter - // sets factor=0 and breaks out of the loop, so no rows match. 
- Statistics stats = createStatistics(VALUES, 0); - stats.getColumnStats().get(0).setCountDistint(0); // force verified-zero NDV - - AnnotateStatsProcCtx ctx = spy(new AnnotateStatsProcCtx(null)); - when(ctx.getConf()).thenReturn(new HiveConf()); - - ExprNodeDesc inExpr = new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, - new GenericUDFIn(), - Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1), createExprNodeConstantDesc(2))); - - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, inExpr, ctx, Arrays.asList(COL_NAME), null, VALUES.length); - - assertEquals(0, numRows); - } - - @Test - public void testComparisonRowCountLessThan() throws SemanticException { + void testComparisonRowCountLessThan() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -376,7 +344,7 @@ public void testComparisonRowCountLessThan() throws SemanticException { } @Test - public void testComparisonRowCountLessThanMin() throws SemanticException { + void testComparisonRowCountLessThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -389,7 +357,7 @@ public void testComparisonRowCountLessThanMin() throws SemanticException { } @Test - public void testComparisonRowCountLessThanBelowMin() throws SemanticException { + void testComparisonRowCountLessThanBelowMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -402,7 +370,7 @@ public void testComparisonRowCountLessThanBelowMin() throws SemanticException { } @Test - public void testComparisonRowCountLessThanMax() throws SemanticException { + void testComparisonRowCountLessThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -415,7 +383,7 @@ public void testComparisonRowCountLessThanMax() throws SemanticException { } @Test - public void testComparisonRowCountLessThanAboveMax() 
throws SemanticException { + void testComparisonRowCountLessThanAboveMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -428,7 +396,7 @@ public void testComparisonRowCountLessThanAboveMax() throws SemanticException { } @Test - public void testComparisonRowCountEqualOrLessThan() throws SemanticException { + void testComparisonRowCountEqualOrLessThan() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -440,7 +408,7 @@ public void testComparisonRowCountEqualOrLessThan() throws SemanticException { } @Test - public void testComparisonRowCountEqualOrLessThanMin() throws SemanticException { + void testComparisonRowCountEqualOrLessThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -452,7 +420,7 @@ public void testComparisonRowCountEqualOrLessThanMin() throws SemanticException } @Test - public void testComparisonRowCountEqualOrLessThanBelowMin() throws SemanticException { + void testComparisonRowCountEqualOrLessThanBelowMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -464,7 +432,7 @@ public void testComparisonRowCountEqualOrLessThanBelowMin() throws SemanticExcep } @Test - public void testComparisonRowCountEqualOrLessThanMax() throws SemanticException { + void testComparisonRowCountEqualOrLessThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -476,7 +444,7 @@ public void testComparisonRowCountEqualOrLessThanMax() throws SemanticException } @Test 
- public void testComparisonRowCountEqualOrLessThanAboveMax() throws SemanticException { + void testComparisonRowCountEqualOrLessThanAboveMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -488,7 +456,7 @@ public void testComparisonRowCountEqualOrLessThanAboveMax() throws SemanticExcep } @Test - public void testComparisonRowCountGreaterThan() throws SemanticException { + void testComparisonRowCountGreaterThan() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -500,7 +468,7 @@ public void testComparisonRowCountGreaterThan() throws SemanticException { } @Test - public void testComparisonRowCountGreaterThanMin() throws SemanticException { + void testComparisonRowCountGreaterThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -512,7 +480,7 @@ public void testComparisonRowCountGreaterThanMin() throws SemanticException { } @Test - public void testComparisonRowCountGreaterThanBelowMin() throws SemanticException { + void testComparisonRowCountGreaterThanBelowMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -524,7 +492,7 @@ public void testComparisonRowCountGreaterThanBelowMin() throws SemanticException } @Test - public void testComparisonRowCountGreaterThanMax() throws SemanticException { + void testComparisonRowCountGreaterThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new 
ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -536,7 +504,7 @@ public void testComparisonRowCountGreaterThanMax() throws SemanticException { } @Test - public void testComparisonRowCountGreaterThanAboveMax() throws SemanticException { + void testComparisonRowCountGreaterThanAboveMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -548,7 +516,7 @@ public void testComparisonRowCountGreaterThanAboveMax() throws SemanticException } @Test - public void testComparisonRowCountEqualOrGreaterThan() throws SemanticException { + void testComparisonRowCountEqualOrGreaterThan() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -560,7 +528,7 @@ public void testComparisonRowCountEqualOrGreaterThan() throws SemanticException } @Test - public void testComparisonRowCountEqualOrGreaterThanMin() throws SemanticException { + void testComparisonRowCountEqualOrGreaterThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -572,7 +540,7 @@ public void testComparisonRowCountEqualOrGreaterThanMin() throws SemanticExcepti } @Test - public void testComparisonRowCountEqualOrGreaterThanBelowMin() throws SemanticException { + void testComparisonRowCountEqualOrGreaterThanBelowMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -584,7 +552,7 @@ public void testComparisonRowCountEqualOrGreaterThanBelowMin() throws SemanticEx } @Test - public void testComparisonRowCountEqualOrGreaterThanMax() throws SemanticException { + void 
testComparisonRowCountEqualOrGreaterThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -596,7 +564,7 @@ public void testComparisonRowCountEqualOrGreaterThanMax() throws SemanticExcepti } @Test - public void testComparisonRowCountEqualOrGreaterThanBeyondMax() throws SemanticException { + void testComparisonRowCountEqualOrGreaterThanBeyondMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, @@ -608,7 +576,7 @@ public void testComparisonRowCountEqualOrGreaterThanBeyondMax() throws SemanticE } @Test - public void testComparisonRowCountEqualOrLessThanWhenMinEqualMax() throws SemanticException { + void testComparisonRowCountEqualOrLessThanWhenMinEqualMax() throws SemanticException { long[] values = { 1L, 1L }; long numNulls = 2; Statistics stats = createStatistics(values, numNulls); @@ -622,7 +590,7 @@ public void testComparisonRowCountEqualOrLessThanWhenMinEqualMax() throws Semant } @Test - public void testComparisonRowCountEqualOrGreaterThanWhenMinEqualMax() throws SemanticException { + void testComparisonRowCountEqualOrGreaterThanWhenMinEqualMax() throws SemanticException { long[] values = { 1L, 1L }; long numNulls = 2; Statistics stats = createStatistics(values, numNulls); @@ -636,7 +604,7 @@ public void testComparisonRowCountEqualOrGreaterThanWhenMinEqualMax() throws Sem } @Test - public void testBetween() throws SemanticException { + void testBetween() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -650,7 +618,7 @@ public void testBetween() throws SemanticException { } @Test - public void testLiteralExtraction() { + void testLiteralExtraction() { final double DELTA = 1e-5; assertEquals((float) 100, @@ -680,7 +648,7 @@ public 
void testLiteralExtraction() { } @Test - public void testLiteralExtractionFailures() { + void testLiteralExtractionFailures() { // make sure the correct exceptions are raised so that we can default to standard computation String[] types = {"int", "tinyint", "smallint", "bigint", "date", "timestamp", "float", "double"}; for (String type : types) { @@ -697,7 +665,7 @@ public void testLiteralExtractionFailures() { } @Test - public void testBetweenLeftLowerThanMin() throws SemanticException { + void testBetweenLeftLowerThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -717,7 +685,7 @@ public void testBetweenLeftLowerThanMin() throws SemanticException { } @Test - public void testBetweenLeftLowerThanMinRightHigherThanMax() throws SemanticException { + void testBetweenLeftLowerThanMinRightHigherThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -731,7 +699,7 @@ public void testBetweenLeftLowerThanMinRightHigherThanMax() throws SemanticExcep } @Test - public void testBetweenRightHigherThanMax() throws SemanticException { + void testBetweenRightHigherThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -751,7 +719,7 @@ public void testBetweenRightHigherThanMax() throws SemanticException { } @Test - public void testBetweenRightLowerThanMin() throws SemanticException { + void testBetweenRightLowerThanMin() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -765,7 +733,7 @@ public void testBetweenRightLowerThanMin() throws SemanticException { } @Test - public void testBetweenLeftHigherThanMax() throws SemanticException { + void testBetweenLeftHigherThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -779,7 +747,7 @@ public void testBetweenLeftHigherThanMax() throws 
SemanticException { } @Test - public void testBetweenLeftEqualMax() throws SemanticException { + void testBetweenLeftEqualMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -793,7 +761,7 @@ public void testBetweenLeftEqualMax() throws SemanticException { } @Test - public void testNotBetween() throws SemanticException { + void testNotBetween() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -818,7 +786,7 @@ public void testNotBetween() throws SemanticException { } @Test - public void testNotBetweenLowerThanMinHigherThanMax() throws SemanticException { + void testNotBetweenLowerThanMinHigherThanMax() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -832,7 +800,7 @@ public void testNotBetweenLowerThanMinHigherThanMax() throws SemanticException { } @Test - public void testNotBetweenLeftEqualsRight() throws SemanticException { + void testNotBetweenLeftEqualsRight() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -846,7 +814,7 @@ public void testNotBetweenLeftEqualsRight() throws SemanticException { } @Test - public void testNotBetweenRightLowerThanLeft() throws SemanticException { + void testNotBetweenRightLowerThanLeft() throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); @@ -908,7 +876,7 @@ private static ColStatistics createColStatistics( * Without the fix, valuesCount = numRows - (-1) = numRows + 1 (wrong). 
*/ @Test - public void testComputeAggregateColumnMinMaxWithUnknownNumNulls() throws SemanticException { + void testComputeAggregateColumnMinMaxWithUnknownNumNulls() throws SemanticException { ColStatistics cs = new ColStatistics("_col0", "bigint"); HiveConf conf = new HiveConf(); @@ -943,7 +911,7 @@ public void testComputeAggregateColumnMinMaxWithUnknownNumNulls() throws Semanti } @Test - public void testComputeAggregateColumnMinMaxWithKnownNumNulls() throws SemanticException { + void testComputeAggregateColumnMinMaxWithKnownNumNulls() throws SemanticException { ColStatistics cs = new ColStatistics("_col0", "bigint"); HiveConf conf = new HiveConf(); @@ -974,17 +942,21 @@ public void testComputeAggregateColumnMinMaxWithKnownNumNulls() throws SemanticE "COUNT max should be 80 (numRows - numNulls)"); } - @Test - public void testComputeAggregateColumnMinMaxDistinctWithUnknownNDVReturnsEarly() throws SemanticException { - // HIVE-29625: for COUNT(DISTINCT col), valuesCount = parentCS.getCountDistint(). - // When that NDV is -1 (unknown), the new guard returns early to avoid building - // a Range with a negative maxValue. + /** + * HIVE-29625: COUNT(DISTINCT col) uses parentCS.getCountDistint() as the max range. + * When NDV is unknown (-1) the new guard short-circuits before building a Range with + * negative maxValue. When NDV is known, Range is set to [0, NDV]. 
+ */ + @ParameterizedTest(name = "{0}") + @MethodSource("computeAggregateColumnMinMaxDistinctCases") + void testComputeAggregateColumnMinMaxDistinctByNDV( + String name, long parentNDV, Long expectedMax) throws SemanticException { ColStatistics cs = new ColStatistics("_col0", "bigint"); HiveConf conf = new HiveConf(); ColStatistics parentColStats = new ColStatistics("val", "int"); parentColStats.setNumNulls(0); - parentColStats.setCountDistint(-1); // unknown NDV + parentColStats.setCountDistint(parentNDV); parentColStats.setRange(1, 100); Statistics parentStats = new Statistics(100, 400, 400, 400); @@ -1001,38 +973,21 @@ public void testComputeAggregateColumnMinMaxDistinctWithUnknownNDVReturnsEarly() StatsRulesProcFactory.GroupByStatsRule.computeAggregateColumnMinMax( cs, conf, agg, "bigint", parentStats); - assertNull(cs.getRange(), "Range should NOT be set when DISTINCT NDV is unknown"); + if (expectedMax == null) { + assertNull(cs.getRange(), "Range should NOT be set when DISTINCT NDV is unknown"); + } else { + assertNotNull(cs.getRange(), "Range should be set when DISTINCT NDV is known"); + assertEquals(0L, ((Number) cs.getRange().minValue).longValue()); + assertEquals(expectedMax.longValue(), ((Number) cs.getRange().maxValue).longValue(), + "COUNT DISTINCT max should equal the parent NDV"); + } } - @Test - public void testComputeAggregateColumnMinMaxDistinctWithKnownNDVSetsRange() throws SemanticException { - // Regression: COUNT(DISTINCT col) with known parentCS.NDV=50 sets Range(0, 50). 
- ColStatistics cs = new ColStatistics("_col0", "bigint"); - HiveConf conf = new HiveConf(); - - ColStatistics parentColStats = new ColStatistics("val", "int"); - parentColStats.setNumNulls(0); - parentColStats.setCountDistint(50); - parentColStats.setRange(1, 100); - - Statistics parentStats = new Statistics(100, 400, 400, 400); - parentStats.addToColumnStats(Collections.singletonList(parentColStats)); - - ExprNodeColumnDesc colExpr = new ExprNodeColumnDesc( - TypeInfoFactory.intTypeInfo, "val", "t", false); - AggregationDesc agg = new AggregationDesc(); - agg.setGenericUDAFName("count"); - agg.setParameters(Collections.singletonList(colExpr)); - agg.setDistinct(true); - agg.setMode(GenericUDAFEvaluator.Mode.COMPLETE); - - StatsRulesProcFactory.GroupByStatsRule.computeAggregateColumnMinMax( - cs, conf, agg, "bigint", parentStats); - - assertNotNull(cs.getRange(), "Range should be set when DISTINCT NDV is known"); - assertEquals(0L, ((Number) cs.getRange().minValue).longValue()); - assertEquals(50L, ((Number) cs.getRange().maxValue).longValue(), - "COUNT DISTINCT max should equal the NDV (50)"); + private static Stream computeAggregateColumnMinMaxDistinctCases() { + return Stream.of( + Arguments.of("unknownNDVReturnsEarlyNoRange", -1L, null), + Arguments.of("knownNDVSetsRangeUpToNDV", 50L, 50L) + ); } /** @@ -1041,7 +996,7 @@ public void testComputeAggregateColumnMinMaxDistinctWithKnownNDVSetsRange() thro * Without the fix, LEFT_OUTER_JOIN would calculate: newNumNulls = oldNumNulls + leftUnmatchedRows = -1 + 100 = 99 */ @Test - public void testUpdateNumNullsPreservesUnknownNumNulls() { + void testUpdateNumNullsPreservesUnknownNumNulls() { StatsRulesProcFactory.JoinStatsRule joinStatsRule = new StatsRulesProcFactory.JoinStatsRule(); // Create ColStatistics with numNulls = -1 (unknown) @@ -1076,7 +1031,7 @@ public void testUpdateNumNullsPreservesUnknownNumNulls() { @ParameterizedTest(name = "{0}") @MethodSource("calculateUnmatchedRowsForOuterCases") - public void 
testCalculateUnmatchedRowsForOuter( + void testCalculateUnmatchedRowsForOuter( String name, long ndv, long distinctUnmatched, long expected) { assertCalculateUnmatchedRowsForOuter(ndv, distinctUnmatched, expected); } @@ -1093,7 +1048,7 @@ private static Stream calculateUnmatchedRowsForOuterCases() { @ParameterizedTest(name = "{0}") @MethodSource("computeRowCountAssumingInnerJoinCases") - public void testComputeRowCountAssumingInnerJoin(String name, long denom, long expected) { + void testComputeRowCountAssumingInnerJoin(String name, long denom, long expected) { assertComputeRowCountAssumingInnerJoin(denom, expected); } @@ -1107,7 +1062,7 @@ private static Stream computeRowCountAssumingInnerJoinCases() { @ParameterizedTest(name = "{0}") @MethodSource("updateColStatsCases") - public void testUpdateColStats(String name, long initialNdv, long expectedNdv) { + void testUpdateColStats(String name, long initialNdv, long expectedNdv) { ColStatistics cs = new ColStatistics("k", "int"); cs.setCountDistint(initialNdv); cs.setNumNulls(0); From 03960fd8b60fd5acc71633589189d911bc00cb4c Mon Sep 17 00:00:00 2001 From: Konstantin Bereznyakov Date: Thu, 28 May 2026 08:08:45 -0700 Subject: [PATCH 5/5] HIVE-29625: more SQ feedback --- .../annotation/StatsRulesProcFactory.java | 22 +- .../optimizer/TestReduceSinkMapJoinProc.java | 17 +- .../TestSetHashGroupByMinReduction.java | 17 +- .../TestSortedDynPartitionOptimizer.java | 6 +- .../annotation/TestStatsRulesProcFactory.java | 256 ++++-------------- 5 files changed, 67 insertions(+), 251 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java index fe65494e0dfe..28bc2623a6a6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/stats/annotation/StatsRulesProcFactory.java @@ 
-559,23 +559,23 @@ private long evaluateInExpr(Statistics stats, ExprNodeDesc pred, long currNumRow for (int i = 0; i < columnStats.size(); i++) { ColStatistics cs = columnStats.get(i); long dvs = cs == null ? -1L : cs.getCountDistint(); - if (dvs < 0) { - // missing stats or unknown NDV - factor *= 0.5; - continue; - } if (dvs == 0) { // verified zero distinct values: IN cannot match any row factor = 0; break; } - // (num of distinct vals for col in IN clause / num of distinct vals for col ) - double columnFactor = 1.0 / dvs; - if (!multiColumn) { - columnFactor *= estimateIntersectionSize(aspCtx.getConf(), columnStats.get(i), values.get(i)); + if (dvs < 0) { + // missing stats or unknown NDV + factor *= 0.5; + } else { + // (num of distinct vals for col in IN clause / num of distinct vals for col ) + double columnFactor = 1.0 / dvs; + if (!multiColumn) { + columnFactor *= estimateIntersectionSize(aspCtx.getConf(), columnStats.get(i), values.get(i)); + } + // max can be 1, even when ndv is larger in IN clause than in column stats + factor *= Math.min(columnFactor, 1.0); } - // max can be 1, even when ndv is larger in IN clause than in column stats - factor *= Math.min(columnFactor, 1.0); } // Clamp at 1 to be sure that we don't get out of range. diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java index 75feeff971ff..58329a228027 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestReduceSinkMapJoinProc.java @@ -59,15 +59,7 @@ class TestReduceSinkMapJoinProc { - /** - * Exercises the keyCount-from-NDV branch in processReduceSinkToHashJoin (HIVE-29625). 
- * Master used `cs.getCountDistint() <= 0` to fall back to MAX_VALUE; HIVE-29625 uses - * `cs.getCountDistint() < 0`, so verified-zero NDV no longer cascades to "no info" - * but falls through to `maxKeyCount *= 0` (then clamped to 1 by later logic). - * - * A null `ndv` row represents StatsUtils.getColStatisticsFromExpression returning null - * (no derivable stat) - shares the same MAX_VALUE fallback as NDV < 0. - */ + // A null ndv row represents StatsUtils.getColStatisticsFromExpression returning null. @ParameterizedTest(name = "{0}") @MethodSource("keyCountFromNdvCases") void testProcessReduceSinkToHashJoinKeyCountFromNdv( @@ -98,12 +90,7 @@ private static Stream keyCountFromNdvCases() { ); } - /** - * Shared invocation harness: build a real GenTezProcContext + mocked operator chain, - * stub StatsUtils.getColStatisticsFromExpression to return the given colStat for - * the single key column, run processReduceSinkToHashJoin, and assert the keyCount - * landed in joinConf.getParentKeyCounts() at position 0. - */ + // Shared harness: build GenTezProcContext + mocked operators, run the method, read the keyCount put(). 
private static void invokeAndAssertKeyCount( ColStatistics csForKey, long parentRows, long expectedKeyCount) throws Exception { diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java index 3d1bdf2f71b2..6c7c2145382b 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSetHashGroupByMinReduction.java @@ -56,12 +56,6 @@ class TestSetHashGroupByMinReduction { private static final float DEFAULT_MIN_REDUCTION = 0.99f; private static final float DEFAULT_MIN_REDUCTION_LOWER_BOUND = 0.1f; - /** - * NDV product drives the central HIVE-29625 disambiguation: - * ndvProduct < 0 -> unknown -> early-return null (no setMinReductionHashAggr call) - * ndvProduct == 0 -> verified zero -> factor = 1.0, NOT less than default 0.99 (no call) - * ndvProduct > 0 -> compute factor = 1 - ndvProduct/numRows, set if < default - */ @ParameterizedTest(name = "{0}") @MethodSource("ndvProductCases") void testProcessByNdvProduct(String name, long ndvProduct, boolean expectSetCall) @@ -88,11 +82,6 @@ private static Stream ndvProductCases() { ); } - /** - * Each gate is one of the early-return conditions in process(): non-HASH mode, - * incomplete basic stats, or incomplete column stats. All three early-return without - * touching setMinReductionHashAggr. - */ @ParameterizedTest(name = "{0}") @MethodSource("earlyReturnGateCases") void testProcessEarlyReturnsOnUnsupportedState(String name, Consumer flipGate) @@ -120,11 +109,7 @@ private static Stream earlyReturnGateCases() { ); } - /** - * Build a GroupByOperator that passes all the early-return gates, with empty keys - * so the colStats loop is a no-op (the inputs to computeNDVGroupingColumns are - * controlled directly via mockStatic in each test). 
- */ + // Passes all early-return gates; empty keys make the colStats loop a no-op. private static GroupByOperator setupCompleteHashGroupBy() { GroupByOperator op = mock(GroupByOperator.class); GroupByDesc desc = mock(GroupByDesc.class); diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java index b92f502e12d1..3fbd063619e3 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/TestSortedDynPartitionOptimizer.java @@ -159,11 +159,7 @@ private static SortedDynPartitionOptimizer.SortedDynamicPartitionProc newProc(Pa return outer.new SortedDynamicPartitionProc(parseCtx); } - /** - * Builds one ColStatistics per ndvs entry; the first entry is null when firstStatNull - * is true (used to simulate "missing stats" for either the partition-column or the - * custom-expression branch). - */ + // Null first entry (firstStatNull) simulates a missing stat for the partition-column or custom-expression branch.
private static ColStatistics[] buildColStats(long[] ndvs, boolean firstStatNull, String prefix) { ColStatistics[] result = new ColStatistics[ndvs.length]; for (int i = 0; i < ndvs.length; i++) { diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java index 7eefc84c4ac8..6e646a6ad502 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/stats/annotation/TestStatsRulesProcFactory.java @@ -330,249 +330,97 @@ private void assertGroupByFinalCardinality(long keyNdv, long expectedRows) throw assertEquals(expectedRows, captor.getValue().getNumRows()); } - @Test - void testComparisonRowCountLessThan() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(3))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(8, numRows); - } - - @Test - void testComparisonRowCountLessThanMin() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(0, numRows); - } - - @Test - void testComparisonRowCountLessThanBelowMin() throws SemanticException { - long numNulls 
= 2; - Statistics stats = createStatistics(VALUES, numNulls); - - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(0))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(0, numRows); - } - - @Test - void testComparisonRowCountLessThanMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(7))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(12, numRows); - } - - @Test - void testComparisonRowCountLessThanAboveMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(8))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(13, numRows); - } - - @Test - void testComparisonRowCountEqualOrLessThan() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(3))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, 
STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(9, numRows); - } - - @Test - void testComparisonRowCountEqualOrLessThanMin() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(1, numRows); - } - - @Test - void testComparisonRowCountEqualOrLessThanBelowMin() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(0))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(0, numRows); - } - - @Test - void testComparisonRowCountEqualOrLessThanMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(7))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(13, numRows); - } - - @Test - void testComparisonRowCountEqualOrLessThanAboveMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new 
ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(8))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(13, numRows); - } - - @Test - void testComparisonRowCountGreaterThan() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(5))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(2, numRows); - } - - @Test - void testComparisonRowCountGreaterThanMin() throws SemanticException { + @ParameterizedTest(name = "{0}") + @MethodSource("comparisonRowCountLessThanCases") + void testComparisonRowCountLessThan(String name, int constant, long expected) throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - assertEquals(12, numRows); - } - - @Test - void testComparisonRowCountGreaterThanBelowMin() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(0))); + new 
GenericUDFOPLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(constant))); long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - assertEquals(13, numRows); + assertEquals(expected, numRows); } - @Test - void testComparisonRowCountGreaterThanMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(7))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(0, numRows); + private static Stream comparisonRowCountLessThanCases() { + return Stream.of( + Arguments.of("midRange", 3, 8L), + Arguments.of("equalToMin", 1, 0L), + Arguments.of("belowMin", 0, 0L), + Arguments.of("equalToMax", 7, 12L), + Arguments.of("aboveMax", 8, 13L) + ); } - @Test - void testComparisonRowCountGreaterThanAboveMax() throws SemanticException { + @ParameterizedTest(name = "{0}") + @MethodSource("comparisonRowCountEqualOrLessThanCases") + void testComparisonRowCountEqualOrLessThan(String name, int constant, long expected) throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(8))); + new GenericUDFOPEqualOrLessThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(constant))); long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - assertEquals(0, numRows); + 
assertEquals(expected, numRows); } - @Test - void testComparisonRowCountEqualOrGreaterThan() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(5))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(3, numRows); + private static Stream comparisonRowCountEqualOrLessThanCases() { + return Stream.of( + Arguments.of("midRange", 3, 9L), + Arguments.of("equalToMin", 1, 1L), + Arguments.of("belowMin", 0, 0L), + Arguments.of("equalToMax", 7, 13L), + Arguments.of("aboveMax", 8, 13L) + ); } - @Test - void testComparisonRowCountEqualOrGreaterThanMin() throws SemanticException { + @ParameterizedTest(name = "{0}") + @MethodSource("comparisonRowCountGreaterThanCases") + void testComparisonRowCountGreaterThan(String name, int constant, long expected) throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(1))); + new GenericUDFOPGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(constant))); long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - assertEquals(13, numRows); + assertEquals(expected, numRows); } - @Test - void testComparisonRowCountEqualOrGreaterThanBelowMin() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new 
ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(0))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(13, numRows); + private static Stream comparisonRowCountGreaterThanCases() { + return Stream.of( + Arguments.of("midRange", 5, 2L), + Arguments.of("equalToMin", 1, 12L), + Arguments.of("belowMin", 0, 13L), + Arguments.of("equalToMax", 7, 0L), + Arguments.of("aboveMax", 8, 0L) + ); } - @Test - void testComparisonRowCountEqualOrGreaterThanMax() throws SemanticException { + @ParameterizedTest(name = "{0}") + @MethodSource("comparisonRowCountEqualOrGreaterThanCases") + void testComparisonRowCountEqualOrGreaterThan(String name, int constant, long expected) throws SemanticException { long numNulls = 2; Statistics stats = createStatistics(VALUES, numNulls); ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(7))); + new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(constant))); long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( stats, exprNodeDesc, STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - assertEquals(1, numRows); + assertEquals(expected, numRows); } - @Test - void testComparisonRowCountEqualOrGreaterThanBeyondMax() throws SemanticException { - long numNulls = 2; - Statistics stats = createStatistics(VALUES, numNulls); - ExprNodeDesc exprNodeDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.intTypeInfo, - new GenericUDFOPEqualOrGreaterThan(), Arrays.asList(COL_EXPR, createExprNodeConstantDesc(8))); - long numRows = new StatsRulesProcFactory.FilterStatsRule().evaluateExpression( - stats, exprNodeDesc, 
STATS_PROC_CTX, Collections.emptyList(), null, VALUES.length + numNulls); - - assertEquals(0, numRows); + private static Stream comparisonRowCountEqualOrGreaterThanCases() { + return Stream.of( + Arguments.of("midRange", 5, 3L), + Arguments.of("equalToMin", 1, 13L), + Arguments.of("belowMin", 0, 13L), + Arguments.of("equalToMax", 7, 1L), + Arguments.of("aboveMax", 8, 0L) + ); } @Test