From 1404adac81be1e84eda05c8d59464e62dd3d9f76 Mon Sep 17 00:00:00 2001 From: Aitozi Date: Tue, 10 Mar 2026 21:25:44 +0800 Subject: [PATCH] [core] Introduce hive bucket function to align with spark hive bucket --- .../generated/core_configuration.html | 2 +- .../java/org/apache/paimon/CoreOptions.java | 7 +- .../apache/paimon/bucket/BucketFunction.java | 2 + .../paimon/bucket/HiveBucketFunction.java | 98 +++++++++++++++ .../org/apache/paimon/bucket/HiveHasher.java | 114 ++++++++++++++++++ .../paimon/bucket/HiveBucketFunctionTest.java | 105 ++++++++++++++++ .../apache/paimon/hash/HiveHasherTest.java | 77 ++++++++++++ .../catalog/functions/PaimonFunctions.scala | 26 ++-- .../catalog/functions/BucketFunctionTest.java | 90 ++++++++++++++ 9 files changed, 506 insertions(+), 15 deletions(-) create mode 100644 paimon-core/src/main/java/org/apache/paimon/bucket/HiveBucketFunction.java create mode 100644 paimon-core/src/main/java/org/apache/paimon/bucket/HiveHasher.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/bucket/HiveBucketFunctionTest.java create mode 100644 paimon-core/src/test/java/org/apache/paimon/hash/HiveHasherTest.java diff --git a/docs/layouts/shortcodes/generated/core_configuration.html b/docs/layouts/shortcodes/generated/core_configuration.html index 5003d23dacc8..673d988dabf7 100644 --- a/docs/layouts/shortcodes/generated/core_configuration.html +++ b/docs/layouts/shortcodes/generated/core_configuration.html @@ -108,7 +108,7 @@
bucket-function.type
default

Enum

- The bucket function for paimon bucket.

Possible values: + The bucket function for paimon bucket.

Possible values:
bucket-key
diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java index 70e50efe9c83..d5934b65dbfb 100644 --- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java @@ -152,7 +152,10 @@ public enum BucketFunctionType implements DescribedEnum { MOD( "mod", "The modulus bucket function which will use modulus arithmetic: bucket_id = Math.floorMod(bucket_key_value, numBuckets) to get bucket. " - + "Note: the bucket key must be a single field of INT or BIGINT datatype."); + + "Note: the bucket key must be a single field of INT or BIGINT datatype."), + HIVE( + "hive", + "The hive bucket function which will use hive-compatible hash arithmetic to get bucket."); private final String value; private final String description; @@ -172,6 +175,8 @@ public static BucketFunctionType of(String bucketType) { return DEFAULT; } else if (MOD.value.equalsIgnoreCase(bucketType)) { return MOD; + } else if (HIVE.value.equalsIgnoreCase(bucketType)) { + return HIVE; } throw new IllegalArgumentException( "cannot match type: " + bucketType + " for bucket function"); diff --git a/paimon-core/src/main/java/org/apache/paimon/bucket/BucketFunction.java b/paimon-core/src/main/java/org/apache/paimon/bucket/BucketFunction.java index f54d17d7646b..6d5149dd6f58 100644 --- a/paimon-core/src/main/java/org/apache/paimon/bucket/BucketFunction.java +++ b/paimon-core/src/main/java/org/apache/paimon/bucket/BucketFunction.java @@ -41,6 +41,8 @@ static BucketFunction create( return new DefaultBucketFunction(); case MOD: return new ModBucketFunction(bucketKeyType); + case HIVE: + return new HiveBucketFunction(bucketKeyType); default: throw new IllegalArgumentException( "Unsupported bucket type: " + bucketFunctionType); diff --git a/paimon-core/src/main/java/org/apache/paimon/bucket/HiveBucketFunction.java 
b/paimon-core/src/main/java/org/apache/paimon/bucket/HiveBucketFunction.java
new file mode 100644
index 000000000000..68f401bba4be
--- /dev/null
+++ b/paimon-core/src/main/java/org/apache/paimon/bucket/HiveBucketFunction.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.bucket;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.InternalRow;
+import org.apache.paimon.types.RowKind;
+import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.InternalRowUtils;
+
+/** {@link BucketFunction} that derives the bucket from a Hive-compatible hash of the key. */
+public class HiveBucketFunction implements BucketFunction {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final int SEED = 0;
+
+    private final InternalRow.FieldGetter[] fieldGetters;
+
+    public HiveBucketFunction(RowType rowType) {
+        this.fieldGetters = InternalRowUtils.createFieldGetters(rowType.getFieldTypes());
+    }
+
+    /** Folds the field hashes as {@code 31 * h + fieldHash}, drops the sign bit, then mods. */
+    @Override
+    public int bucket(BinaryRow row, int numBuckets) {
+        assert numBuckets > 0 && row.getRowKind() == RowKind.INSERT : "Num bucket is illegal";
+
+        int combined = SEED;
+        for (int pos = 0; pos < row.getFieldCount(); pos++) {
+            Object field = fieldGetters[pos].getFieldOrNull(row);
+            combined = 31 * combined + computeHash(field);
+        }
+        return mod(combined & Integer.MAX_VALUE, numBuckets);
+    }
+
+    /** Remainder of {@code value / divisor}; a negative one gets {@code divisor} added back. */
+    static int mod(int value, int divisor) {
+        int remainder = value % divisor;
+        return remainder < 0 ? (remainder + divisor) % divisor : remainder;
+    }
+
+    /** Hashes a single field value; {@code null} hashes to 0, unknown classes are rejected. */
+    private int computeHash(Object value) {
+        if (value == null) {
+            return 0;
+        }
+        if (value instanceof Boolean) {
+            return HiveHasher.hashInt((Boolean) value ? 1 : 0);
+        }
+        if (value instanceof Byte || value instanceof Short || value instanceof Integer) {
+            return HiveHasher.hashInt(((Number) value).intValue());
+        }
+        if (value instanceof Long) {
+            return HiveHasher.hashLong((Long) value);
+        }
+        if (value instanceof Float) {
+            // "== -0.0f" also matches +0.0f, so both zeros hash as 0.
+            float f = (Float) value;
+            return HiveHasher.hashInt(f == -0.0f ? 0 : Float.floatToIntBits(f));
+        }
+        if (value instanceof Double) {
+            double d = (Double) value;
+            return HiveHasher.hashLong(d == -0.0d ? 0L : Double.doubleToLongBits(d));
+        }
+        if (value instanceof BinaryString) {
+            BinaryString text = (BinaryString) value;
+            return HiveHasher.hashUnsafeBytes(
+                    text.getSegments(), text.getOffset(), text.getSizeInBytes());
+        }
+        if (value instanceof byte[]) {
+            return HiveHasher.hashBytes((byte[]) value);
+        }
+        if (value instanceof Decimal) {
+            return HiveHasher.normalizeDecimal(((Decimal) value).toBigDecimal()).hashCode();
+        }
+        throw new UnsupportedOperationException(
+                "Unsupported type as bucket key type " + value.getClass());
+    }
+}
diff --git a/paimon-core/src/main/java/org/apache/paimon/bucket/HiveHasher.java b/paimon-core/src/main/java/org/apache/paimon/bucket/HiveHasher.java
new file mode 100644
index 000000000000..7103f4d9b77f
--- /dev/null
+++ b/paimon-core/src/main/java/org/apache/paimon/bucket/HiveHasher.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.bucket;
+
+import org.apache.paimon.memory.MemorySegment;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+
+/** Hive hash util. 
*/
+public class HiveHasher {
+
+    private static final int HIVE_DECIMAL_MAX_PRECISION = 38;
+    private static final int HIVE_DECIMAL_MAX_SCALE = 38;
+
+    @Override
+    public String toString() {
+        return HiveHasher.class.getSimpleName();
+    }
+
+    /** An int is its own hash. */
+    public static int hashInt(int input) {
+        return input;
+    }
+
+    /** Hashes a long the way {@link Long#hashCode(long)} does. */
+    public static int hashLong(long input) {
+        return Long.hashCode(input);
+    }
+
+    /** Base-31 polynomial hash over the signed byte values, starting from 0. */
+    public static int hashBytes(byte[] bytes) {
+        int hash = 0;
+        for (int pos = 0; pos < bytes.length; pos++) {
+            hash = 31 * hash + bytes[pos];
+        }
+        return hash;
+    }
+
+    /**
+     * Applies the {@link #hashBytes(byte[])} recurrence to {@code length} bytes that start at
+     * {@code offset}, walking across {@code segments} in order.
+     */
+    public static int hashUnsafeBytes(MemorySegment[] segments, int offset, int length) {
+        int hash = 0;
+        for (MemorySegment segment : segments) {
+            int remaining = segment.size() - offset;
+            if (remaining <= 0) {
+                // The range starts in a later segment; skip this one entirely.
+                offset -= segment.size();
+            } else {
+                int end = offset + Math.min(remaining, length);
+                for (int pos = offset; pos < end; pos++) {
+                    hash = 31 * hash + segment.get(pos);
+                }
+                length -= end - offset;
+                offset = 0;
+            }
+            if (length == 0) {
+                break;
+            }
+        }
+        return hash;
+    }
+
+    /**
+     * Trims {@code input} and, if its scale exceeds what the 38-digit limits allow, rounds it
+     * HALF_UP and trims again. Returns {@code null} for {@code null} or over 38 integer digits.
+     */
+    public static BigDecimal normalizeDecimal(BigDecimal input) {
+        if (input == null) {
+            return null;
+        }
+        BigDecimal trimmed = trimDecimal(input);
+        int integerDigits = trimmed.precision() - trimmed.scale();
+        if (integerDigits > HIVE_DECIMAL_MAX_PRECISION) {
+            return null;
+        }
+        int scaleRoom = Math.min(HIVE_DECIMAL_MAX_PRECISION - integerDigits, trimmed.scale());
+        int maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, scaleRoom);
+        if (trimmed.scale() <= maxScale) {
+            return trimmed;
+        }
+        return trimDecimal(trimmed.setScale(maxScale, RoundingMode.HALF_UP));
+    }
+
+    /** Strips trailing zeros, never leaving a negative scale; zero maps to {@code ZERO}. */
+    private static BigDecimal trimDecimal(BigDecimal input) {
+        if (input.signum() == 0) {
+            return BigDecimal.ZERO;
+        }
+        BigDecimal stripped = input.stripTrailingZeros();
+        if (stripped.signum() == 0) {
+            return BigDecimal.ZERO;
+        }
+        return stripped.scale() < 0 ? stripped.setScale(0) : stripped;
+    }
+}
diff --git 
a/paimon-core/src/test/java/org/apache/paimon/bucket/HiveBucketFunctionTest.java b/paimon-core/src/test/java/org/apache/paimon/bucket/HiveBucketFunctionTest.java
new file mode 100644
index 000000000000..fb5ecc61b46a
--- /dev/null
+++ b/paimon-core/src/test/java/org/apache/paimon/bucket/HiveBucketFunctionTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.bucket;
+
+import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.serializer.InternalRowSerializer;
+import org.apache.paimon.types.DataTypes;
+import org.apache.paimon.types.RowType;
+
+import org.junit.jupiter.api.Test;
+
+import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link HiveBucketFunction}. */
+class HiveBucketFunctionTest {
+
+    /** A four-field key lands in the bucket obtained by folding the per-field hashes. */
+    @Test
+    void testHiveBucketFunction() {
+        RowType keyType =
+                RowType.of(
+                        DataTypes.INT(),
+                        DataTypes.STRING(),
+                        DataTypes.BYTES(),
+                        DataTypes.DECIMAL(10, 4));
+        HiveBucketFunction function = new HiveBucketFunction(keyType);
+
+        BinaryRow row =
+                toBinaryRow(
+                        keyType,
+                        7,
+                        BinaryString.fromString("hello"),
+                        new byte[] {1, 2, 3},
+                        Decimal.fromBigDecimal(new BigDecimal("12.3400"), 10, 4));
+
+        // Fold the expected field hashes left to right: h = 31 * h + fieldHash, from h = 0.
+        byte[] helloUtf8 = "hello".getBytes(StandardCharsets.UTF_8);
+        int expectedHash = 7; // 31 * 0 + hashInt(7)
+        expectedHash = 31 * expectedHash + HiveHasher.hashBytes(helloUtf8);
+        expectedHash = 31 * expectedHash + HiveHasher.hashBytes(new byte[] {1, 2, 3});
+        // The decimal contributes the hash of its value without trailing zeros.
+        expectedHash = 31 * expectedHash + new BigDecimal("12.34").hashCode();
+
+        int expectedBucket = (expectedHash & Integer.MAX_VALUE) % 8;
+        assertThat(function.bucket(row, 8)).isEqualTo(expectedBucket);
+    }
+
+    /** A key made only of nulls hashes to 0 and therefore goes to bucket 0. */
+    @Test
+    void testHiveBucketFunctionWithNulls() {
+        RowType keyType = RowType.of(DataTypes.INT(), DataTypes.STRING());
+        HiveBucketFunction function = new HiveBucketFunction(keyType);
+
+        BinaryRow allNulls = toBinaryRow(keyType, null, null);
+
+        // Every null field contributes 0, so the folded hash stays 0.
+        assertThat(function.bucket(allNulls, 4)).isZero();
+    }
+
+    /** Non-null values of a type without a hash rule are rejected; nulls are still accepted. */
+    @Test
+    void testHiveBucketFunctionUnsupportedType() {
+        RowType keyType = RowType.of(DataTypes.TIMESTAMP());
+        HiveBucketFunction function = new HiveBucketFunction(keyType);
+
+        // A null key is hashed before any type dispatch, so it still maps to bucket 0.
+        assertThat(function.bucket(toBinaryRow(keyType, (Object) null), 4)).isZero();
+
+        // A non-null timestamp has no hash rule and must be rejected.
+        org.apache.paimon.data.Timestamp oneMilli =
+                org.apache.paimon.data.Timestamp.fromEpochMillis(1L);
+        assertThatThrownBy(() -> function.bucket(toBinaryRow(keyType, oneMilli), 4))
+                .isInstanceOf(UnsupportedOperationException.class)
+                .hasMessageContaining("Unsupported type as bucket key type");
+    }
+
+    /** Serializes {@code values} into a {@link BinaryRow} of type {@code rowType}. */
+    private BinaryRow toBinaryRow(RowType rowType, Object... values) {
+        InternalRowSerializer serializer = new InternalRowSerializer(rowType);
+        return serializer.toBinaryRow(GenericRow.of(values));
+    }
+}
diff --git a/paimon-core/src/test/java/org/apache/paimon/hash/HiveHasherTest.java b/paimon-core/src/test/java/org/apache/paimon/hash/HiveHasherTest.java
new file mode 100644
index 000000000000..dac6d97dda5c
--- /dev/null
+++ b/paimon-core/src/test/java/org/apache/paimon/hash/HiveHasherTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.hash;
+
+import org.apache.paimon.bucket.HiveHasher;
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.memory.MemorySegment;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Random;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link HiveHasher}. 
*/
+class HiveHasherTest {
+
+    @Test
+    void testHashUnsafeBytes() {
+        // A short string created directly by BinaryString.fromString.
+        assertMatchesHashBytes(BinaryString.fromString("hello"));
+
+        // A long string re-homed into two segments behind a random pad.
+        StringBuilder text = new StringBuilder();
+        for (int round = 0; round < 1000; round++) {
+            text.append(UUID.randomUUID());
+        }
+        assertMatchesHashBytes(fromString(text.toString()));
+    }
+
+    /** Asserts that hashing the backing segments equals hashing the copied-out bytes. */
+    private void assertMatchesHashBytes(BinaryString value) {
+        int segmentHash =
+                HiveHasher.hashUnsafeBytes(
+                        value.getSegments(), value.getOffset(), value.getSizeInBytes());
+        assertThat(segmentHash).isEqualTo(HiveHasher.hashBytes(value.toBytes()));
+    }
+
+    /** Copies {@code input} into two equally sized segments, starting {@code pad} bytes in. */
+    private BinaryString fromString(String input) {
+        BinaryString source = BinaryString.fromString(input);
+        int numBytes = source.getSizeInBytes();
+        int pad = new Random().nextInt(100);
+        int totalBytes = numBytes + pad;
+        int segmentSize = totalBytes / 2 + 1;
+        int headLength = segmentSize - pad;
+        byte[] first = new byte[segmentSize];
+        byte[] second = new byte[segmentSize];
+        MemorySegment sourceSegment = source.getSegments()[0];
+        // The first headLength bytes follow the pad in segment one; the rest open segment two.
+        if (headLength > 0 && numBytes >= headLength) {
+            sourceSegment.get(0, first, pad, headLength);
+        }
+        sourceSegment.get(headLength, second, 0, numBytes - headLength);
+        return BinaryString.fromAddress(
+                new MemorySegment[] {MemorySegment.wrap(first), MemorySegment.wrap(second)},
+                pad,
+                numBytes);
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalog/functions/PaimonFunctions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalog/functions/PaimonFunctions.scala
index acdd40e9b244..dd039de6cc2d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalog/functions/PaimonFunctions.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalog/functions/PaimonFunctions.scala
@@ -43,29 +43,29 @@ object 
PaimonFunctions { val PAIMON_BUCKET: String = "bucket" val MOD_BUCKET: String = "mod_bucket" + val HIVE_BUCKET: String = "hive_bucket" val MAX_PT: String = "max_pt" val PATH_TO_DESCRIPTOR: String = "path_to_descriptor" val DESCRIPTOR_TO_STRING: String = "descriptor_to_string" - private val FUNCTIONS = ImmutableMap.of( - PAIMON_BUCKET, - new BucketFunction(PAIMON_BUCKET, BucketFunctionType.DEFAULT), - MOD_BUCKET, - new BucketFunction(MOD_BUCKET, BucketFunctionType.MOD), - MAX_PT, - new MaxPtFunction, - PATH_TO_DESCRIPTOR, - new PathToDescriptorUnbound, - DESCRIPTOR_TO_STRING, - new DescriptorToStringUnbound - ) + private val FUNCTIONS = ImmutableMap + .builder[String, UnboundFunction]() + .put(PAIMON_BUCKET, new BucketFunction(PAIMON_BUCKET, BucketFunctionType.DEFAULT)) + .put(MOD_BUCKET, new BucketFunction(MOD_BUCKET, BucketFunctionType.MOD)) + .put(HIVE_BUCKET, new BucketFunction(HIVE_BUCKET, BucketFunctionType.HIVE)) + .put(MAX_PT, new MaxPtFunction) + .put(PATH_TO_DESCRIPTOR, new PathToDescriptorUnbound) + .put(DESCRIPTOR_TO_STRING, new DescriptorToStringUnbound) + .build() /** The bucket function type to the function name mapping */ private val TYPE_FUNC_MAPPING = ImmutableMap.of( BucketFunctionType.DEFAULT, PAIMON_BUCKET, BucketFunctionType.MOD, - MOD_BUCKET + MOD_BUCKET, + BucketFunctionType.HIVE, + HIVE_BUCKET ) val names: ImmutableSet[String] = FUNCTIONS.keySet diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java index 214965bf15a2..43425486c84d 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/catalog/functions/BucketFunctionTest.java @@ -54,7 +54,14 @@ import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap; 
+import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.HiveHash; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; @@ -64,6 +71,8 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -72,6 +81,8 @@ import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; +import scala.collection.JavaConverters; + /** Tests for Spark bucket functions. */ public class BucketFunctionTest { private static final int NUM_BUCKETS = @@ -243,6 +254,69 @@ private static void validateSparkBucketFunction(String... 
bucketColumns) {
                 .forEach(row -> Assertions.assertThat(row.getInt(2)).isEqualTo(row.get(1)));
     }
 
+    @Test
+    public void testHiveBucketFunctionMatchesSparkHiveHash() {
+        RowType hiveBucketRowType =
+                new RowType(
+                        Arrays.asList(
+                                new DataField(0, BOOLEAN_COL, new BooleanType()),
+                                new DataField(1, BYTE_COL, new TinyIntType()),
+                                new DataField(2, SHORT_COL, new SmallIntType()),
+                                new DataField(3, INTEGER_COL, new IntType()),
+                                new DataField(4, LONG_COL, new BigIntType()),
+                                new DataField(5, FLOAT_COL, new FloatType()),
+                                new DataField(6, DOUBLE_COL, new DoubleType()),
+                                new DataField(
+                                        7, STRING_COL, new VarCharType(VarCharType.MAX_LENGTH)),
+                                new DataField(
+                                        8,
+                                        DECIMAL_COL,
+                                        new DecimalType(DECIMAL_PRECISION, DECIMAL_SCALE)),
+                                new DataField(
+                                        9,
+                                        COMPACTED_DECIMAL_COL,
+                                        new DecimalType(
+                                                COMPACTED_DECIMAL_PRECISION,
+                                                COMPACTED_DECIMAL_SCALE)),
+                                new DataField(
+                                        10,
+                                        BINARY_COL,
+                                        new VarBinaryType(VarBinaryType.MAX_LENGTH))));
+        StructType schema =
+                org.apache.paimon.spark.SparkTypeUtils.fromPaimonRowType(hiveBucketRowType);
+        String[] bucketColumns = hiveBucketRowType.getFieldNames().toArray(new String[0]);
+        List<Row> rows =
+                Arrays.asList(
+                        RowFactory.create(
+                                true,
+                                (byte) 1,
+                                (short) 2,
+                                3,
+                                4L,
+                                1.5f,
+                                2.5d,
+                                "hello",
+                                new BigDecimal("12.340000000000000000"),
+                                new BigDecimal("56.789000000"),
+                                "spark-hive".getBytes(StandardCharsets.UTF_8)),
+                        RowFactory.create(
+                                null, null, null, null, null, null, null, null, null, null, null));
+        // Run the SQL function over both rows, keeping every input column next to its bucket.
+        List<Row> result =
+                spark.createDataFrame(rows, schema)
+                        .selectExpr(
+                                "*",
+                                String.format(
+                                        "paimon.sys.hive_bucket(%s, %s) as actual_bucket",
+                                        NUM_BUCKETS, String.join(", ", bucketColumns)))
+                        .collectAsList();
+        // Each bucket must equal the one derived from Spark's own HiveHash expression.
+        for (Row row : result) {
+            Assertions.assertThat((int) row.getAs("actual_bucket"))
+                    .isEqualTo(sparkHiveBucket(row, schema, bucketColumns));
+        }
+    }
+
     @Test
     public void testBooleanType() {
         validateSparkBucketFunction(BOOLEAN_COL);
@@ -335,4 +409,20 @@ public void 
testTimestampPrecisionNotEqualToSpark() {
                 .collectAsList()
                 .forEach(row -> Assertions.assertThat(row.getInt(2)).isNotEqualTo(row.get(1)));
     }
+
+    private static int sparkHiveBucket(Row row, StructType schema, String... bucketColumns) {
+        List<Expression> expressions = new ArrayList<>();
+        for (String bucketColumn : bucketColumns) {
+            int index = schema.fieldIndex(bucketColumn);
+            StructField field = schema.fields()[index];
+            expressions.add(
+                    Literal.create(row.isNullAt(index) ? null : row.get(index), field.dataType()));
+        }
+        // The expressions are all literals, so HiveHash is evaluated without an input row.
+        int hash =
+                (Integer)
+                        new HiveHash(JavaConverters.asScalaBuffer(expressions).toSeq())
+                                .eval((org.apache.spark.sql.catalyst.InternalRow) null);
+        return (hash & Integer.MAX_VALUE) % NUM_BUCKETS;
+    }
 }