Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.vectorized;

import org.apache.gluten.execution.BroadCastHashJoinContext;
import org.apache.gluten.execution.BroadcastJoinContext;
import org.apache.gluten.execution.JoinTypeTransform;
import org.apache.gluten.expression.ConverterUtils$;
import org.apache.gluten.utils.SubstraitUtil;
Expand All @@ -31,16 +31,17 @@

public class StorageJoinBuilder {

public static native void nativeCleanBuildHashTable(String hashTableId, long hashTableData);
public static native void nativeCleanBuildHashTable(int hashTableId, long hashTableData);

public static native long nativeCloneBuildHashTable(long hashTableData);

private static native long nativeBuild(
String buildHashTableId,
int buildTableId,
byte[] in,
long rowCount,
String joinKeys,
int joinType,
boolean isBhj,
boolean hasMixedFiltCondition,
boolean isExistenceJoin,
byte[] namedStruct,
Expand All @@ -53,16 +54,16 @@ private StorageJoinBuilder() {}
public static long build(
byte[] batches,
long rowCount,
BroadCastHashJoinContext broadCastContext,
BroadcastJoinContext broadcastContext,
List<Expression> newBuildKeys,
List<Attribute> newOutput,
boolean hasNullKeyValues) {
ConverterUtils$ converter = ConverterUtils$.MODULE$;
List<Expression> keys;
List<Attribute> output;
if (newBuildKeys.isEmpty()) {
keys = JavaConverters.<Expression>seqAsJavaList(broadCastContext.buildSideJoinKeys());
output = JavaConverters.<Attribute>seqAsJavaList(broadCastContext.buildSideStructure());
keys = JavaConverters.<Expression>seqAsJavaList(broadcastContext.buildSideJoinKeys());
output = JavaConverters.<Attribute>seqAsJavaList(broadcastContext.buildSideStructure());
} else {
keys = newBuildKeys;
output = newOutput;
Expand All @@ -77,24 +78,25 @@ public static long build(
.collect(Collectors.joining(","));

int joinType;
if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) {
joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal();
} else {
boolean buildRight = broadCastContext.buildRight();
if (broadcastContext.isBhj()) {
boolean buildRight = broadcastContext.buildRight();
joinType =
JoinTypeTransform.toSubstraitJoinType(broadCastContext.joinType(), buildRight).ordinal();
JoinTypeTransform.toSubstraitJoinType(broadcastContext.joinType(), buildRight).ordinal();
} else {
joinType = SubstraitUtil.toCrossRelSubstrait(broadcastContext.joinType()).ordinal();
}

return nativeBuild(
broadCastContext.buildHashTableId(),
broadcastContext.buildTableId(),
batches,
rowCount,
joinKey,
joinType,
broadCastContext.hasMixedFiltCondition(),
broadCastContext.isExistenceJoin(),
broadcastContext.isBhj(),
broadcastContext.hasMixedFiltCondition(),
broadcastContext.isExistenceJoin(),
SubstraitUtil.toNameStruct(output).toByteArray(),
broadCastContext.isNullAwareAntiJoin(),
broadcastContext.isNullAwareAntiJoin(),
hasNullKeyValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ case class BroadcastHashTable(pointer: Long, relation: ClickHouseBuildSideRelati
* The complicated part is due to reuse exchange, where multiple BHJ IDs correspond to a
* `ClickHouseBuildSideRelation`.
*/
object CHBroadcastBuildSideCache extends Logging with RemovalListener[String, BroadcastHashTable] {
object CHBroadcastBuildSideCache extends Logging with RemovalListener[Int, BroadcastHashTable] {

private lazy val expiredTime = SparkEnv.get.conf.getLong(
CHBackendSettings.GLUTEN_CLICKHOUSE_BROADCAST_CACHE_EXPIRED_TIME,
Expand All @@ -45,37 +45,37 @@ object CHBroadcastBuildSideCache extends Logging with RemovalListener[String, Br

// Use for controlling to build bhj hash table once.
// key: hashtable id, value is hashtable backend pointer(long to string).
private val buildSideRelationCache: Cache[String, BroadcastHashTable] =
private val buildSideRelationCache: Cache[Int, BroadcastHashTable] =
Caffeine.newBuilder
.expireAfterAccess(expiredTime, TimeUnit.SECONDS)
.removalListener(this)
.build[String, BroadcastHashTable]()
.build[Int, BroadcastHashTable]()

def getOrBuildBroadcastHashTable(
broadcast: Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext): BroadcastHashTable = {
broadcastContext: BroadcastJoinContext): BroadcastHashTable = {

buildSideRelationCache
.get(
broadCastContext.buildHashTableId,
(broadcast_id: String) => {
broadcastContext.buildTableId,
(broadcastId: String) => {
val (pointer, relation) =
broadcast.value
.asInstanceOf[ClickHouseBuildSideRelation]
.buildHashTable(broadCastContext)
logDebug(s"Create bhj $broadcast_id = 0x${pointer.toHexString}")
.buildHashTable(broadcastContext)
logDebug(s"Create bhj $broadcastId = 0x${pointer.toHexString}")
BroadcastHashTable(pointer, relation)
}
)
}

/** This is callback from c++ backend. */
def get(broadcastHashtableId: String): Long =
/** This is called from c++ side. */
def get(broadcastHashtableId: Int): Long =
Option(buildSideRelationCache.getIfPresent(broadcastHashtableId))
.map(_.pointer)
.getOrElse(0)

def invalidateBroadcastHashtable(broadcastHashtableId: String): Unit = {
def invalidateBroadcastHashtable(broadcastHashtableId: Int): Unit = {
// Cleanup operations on the backend are idempotent.
buildSideRelationCache.invalidate(broadcastHashtableId)
}
Expand All @@ -85,7 +85,7 @@ object CHBroadcastBuildSideCache extends Logging with RemovalListener[String, Br

def cleanAll(): Unit = buildSideRelationCache.invalidateAll()

override def onRemoval(key: String, value: BroadcastHashTable, cause: RemovalCause): Unit = {
override def onRemoval(key: Int, value: BroadcastHashTable, cause: RemovalCause): Unit = {
logDebug(s"Remove bhj $key = 0x${value.pointer.toHexString}")
if (value.relation != null) {
value.relation.reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,21 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId)
} else {
logWarning(
s"Can't not trace broadcast table data $buildBroadcastTableId" +
s"Can not trace broadcast table data $buildBroadcastTableId" +
s" because execution id is null." +
s" Will clean up until expire time.")
}
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
val context =
BroadCastHashJoinContext(
BroadcastJoinContext(
Seq.empty,
finalJoinType,
buildSide == BuildRight,
false,
joinType.isInstanceOf[ExistenceJoin],
buildPlan.output,
buildBroadcastTableId)
buildBroadcastTableId,
false)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
streamedRDD :+ broadcastRDD
}
Expand All @@ -99,9 +100,9 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
val joinParametersStr = new StringBuffer("JoinParameters:")
joinParametersStr
.append("isBHJ=")
.append(1)
.append(0)
.append("\n")
.append("buildHashTableId=")
.append("buildBroadcastTableId=")
.append(buildBroadcastTableId)
.append("\n")
val message = StringValue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ case class CHShuffledHashJoinExecTransformer(
case class CHBroadcastBuildSideRDD(
@transient private val sc: SparkContext,
broadcasted: broadcast.Broadcast[BuildSideRelation],
broadcastContext: BroadCastHashJoinContext)
broadcastContext: BroadcastJoinContext)
extends BroadcastBuildSideRDD(sc, broadcasted) {

override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = {
Expand All @@ -198,14 +198,15 @@ case class CHBroadcastBuildSideRDD(
}
}

case class BroadCastHashJoinContext(
case class BroadcastJoinContext(
buildSideJoinKeys: Seq[Expression],
joinType: JoinType,
buildRight: Boolean,
hasMixedFiltCondition: Boolean,
isExistenceJoin: Boolean,
buildSideStructure: Seq[Attribute],
buildHashTableId: String,
buildTableId: Int,
isBhj: Boolean,
isNullAwareAntiJoin: Boolean = false)

case class CHBroadcastHashJoinExecTransformer(
Expand Down Expand Up @@ -259,14 +260,15 @@ case class CHBroadcastHashJoinExecTransformer(
}
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
val context =
BroadCastHashJoinContext(
BroadcastJoinContext(
buildKeyExprs,
joinType,
buildSide == BuildRight,
isMixedCondition(condition),
joinType.isInstanceOf[ExistenceJoin],
buildPlan.output,
buildHashTableId,
true,
isNullAwareAntiJoin
)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.joins

import org.apache.gluten.execution.{BroadCastHashJoinContext, ColumnarNativeIterator}
import org.apache.gluten.execution.{BroadcastJoinContext, ColumnarNativeIterator}
import org.apache.gluten.utils.{IteratorUtil, PlanNodesUtil}
import org.apache.gluten.vectorized._

Expand Down Expand Up @@ -46,18 +46,17 @@ case class ClickHouseBuildSideRelation(

private var hashTableData: Long = 0L

def buildHashTable(
broadCastContext: BroadCastHashJoinContext): (Long, ClickHouseBuildSideRelation) =
def buildHashTable(broadcastContext: BroadcastJoinContext): (Long, ClickHouseBuildSideRelation) =
synchronized {
if (hashTableData == 0) {
logDebug(
s"BHJ value size: " +
s"${broadCastContext.buildHashTableId} = ${batches.length}")
s"${broadcastContext.buildTableId} = ${batches.length}")
// Build the hash table
hashTableData = StorageJoinBuilder.build(
batches,
numOfRows,
broadCastContext,
broadcastContext,
newBuildKeys.asJava,
output.asJava,
hasNullKeyValues)
Expand All @@ -79,7 +78,7 @@ case class ClickHouseBuildSideRelation(
override def transform(key: Expression): Array[InternalRow] = {
// native block reader
val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)
val broadcastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)

val transformProjections = mode match {
case HashedRelationBroadcastMode(k, _) => k
Expand All @@ -88,7 +87,7 @@ case class ClickHouseBuildSideRelation(

// Expression compute, return block iterator
val expressionEval = new SimpleExpressionEval(
new ColumnarNativeIterator(broadCastIter.asJava),
new ColumnarNativeIterator(broadcastIter.asJava),
PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))

val proj = UnsafeProjection.create(Seq(key))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.benchmarks

import org.apache.gluten.execution.BroadCastHashJoinContext
import org.apache.gluten.execution.BroadcastJoinContext
import org.apache.gluten.vectorized.StorageJoinBuilder

import org.apache.spark.benchmark.Benchmark
Expand Down Expand Up @@ -87,15 +87,15 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w
new util.ArrayList[Expression](),
new util.ArrayList[Attribute](),
false)
StorageJoinBuilder.nativeCleanBuildHashTable("", table)
StorageJoinBuilder.nativeCleanBuildHashTable(0, table)
}
}
}
benchmark.run()
}

private def createBroadcastRelation(
child: SparkPlan): (Array[Byte], Long, BroadCastHashJoinContext) = {
child: SparkPlan): (Array[Byte], Long, BroadcastJoinContext) = {
val dataSize = SQLMetrics.createSizeMetric(spark.sparkContext, "size of files read")

val countsAndBytes = child
Expand All @@ -105,7 +105,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w
(
countsAndBytes.flatMap(_._2),
countsAndBytes.map(_._1).sum,
BroadCastHashJoinContext(Seq(child.output.head), Inner, true, false, false, child.output, "")
BroadcastJoinContext(Seq(child.output.head), Inner, true, false, false, child.output, 0, true)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ object CHStorageJoinBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark

def iterateBatch(array: Array[Byte], compressed: Boolean): Int = {
val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(array, compressed))
val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)
broadCastIter.foldLeft(0) {
val broadcastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)
broadcastIter.foldLeft(0) {
case (acc, batch) =>
acc + batch.numRows
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ case class ShuffledHashJoinExecTransformer(
newLeft: SparkPlan,
newRight: SparkPlan): ShuffledHashJoinExecTransformer =
copy(left = newLeft, right = newRight)

override def genJoinParametersInternal(): (Int, Int, Int) = {
(0, 0, buildPlan.id)
}
}

case class BroadcastHashJoinExecTransformer(
Expand Down
Loading