Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ class AggregateOpDesc extends LogicalOp {
val inputSchema = inputSchemas(operatorInfo.inputPorts.head.id)
val outputSchema = Schema(
groupByKeys.map(key => inputSchema.getAttribute(key)) ++
localAggregations.map(agg =>
agg.getAggregationAttribute(inputSchema.getAttribute(agg.attribute).getType)
)
localAggregations.map { agg =>
// COUNT(*) ignores the attribute entirely, so never look up an input
// column for it (a stale/leaked attribute may not exist in the schema).
val attrType =
if (agg.aggFunction == AggregationFunction.COUNT_STAR) null
else inputSchema.getAttribute(agg.attribute).getType
agg.getAggregationAttribute(attrType)
}
)
Map(PortIdentity(internal = true) -> outputSchema)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,15 @@ class AggregateOpExec(descString: String) extends OperatorExecutor {

// Initialize distributedAggregations if it's not yet initialized
if (distributedAggregations == null) {
distributedAggregations = desc.aggregations.map(agg =>
agg.getAggFunc(tuple.getSchema.getAttribute(agg.attribute).getType)
)
distributedAggregations = desc.aggregations.map { agg =>
// COUNT(*) ignores the attribute entirely, so never look up an input column for
// it (a stale/leaked attribute may not exist in the schema). Its result type
// does not depend on any input attribute.
val attrType =
if (agg.aggFunction == AggregationFunction.COUNT_STAR) null
else tuple.getSchema.getAttribute(agg.attribute).getType
agg.getAggFunc(attrType)
}
}

// Construct the group key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public enum AggregationFunction {

COUNT("count"),

COUNT_STAR("count(*)"),
Comment thread
tanishqgandhi1908 marked this conversation as resolved.

AVERAGE("average"),

MIN("min"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,23 @@ case class AveragePartialObj(sum: Double, count: Double) extends Serializable {}
}
]
}
}
},
"allOf": [
{
"if": {
"properties": {
"aggFunction": { "const": "count(*)" }
}
},
"then": {},
"else": {
"required": ["attribute"],
"properties": {
"attribute": { "minLength": 1 }
}
}
}
]
}
""")
class AggregationOperation {
Expand All @@ -64,8 +80,8 @@ class AggregationOperation {
@JsonPropertyDescription("sum, count, average, min, max, or concat")
var aggFunction: AggregationFunction = _

@JsonProperty(value = "attribute", required = true)
@JsonPropertyDescription("column to calculate average value")
@JsonProperty(value = "attribute")
@JsonPropertyDescription("column to aggregate on")
Comment thread
tanishqgandhi1908 marked this conversation as resolved.
@AutofillAttributeName
var attribute: String = _

Expand All @@ -77,26 +93,28 @@ class AggregationOperation {
@JsonIgnore
def getAggregationAttribute(attrType: AttributeType): Attribute = {
val resultAttrType = this.aggFunction match {
case AggregationFunction.SUM => attrType
case AggregationFunction.COUNT => AttributeType.INTEGER
case AggregationFunction.AVERAGE => AttributeType.DOUBLE
case AggregationFunction.MIN => attrType
case AggregationFunction.MAX => attrType
case AggregationFunction.CONCAT => AttributeType.STRING
case _ => throw new RuntimeException("Unknown aggregation function: " + this.aggFunction)
case AggregationFunction.SUM => attrType
case AggregationFunction.COUNT => AttributeType.INTEGER
case AggregationFunction.COUNT_STAR => AttributeType.INTEGER
case AggregationFunction.AVERAGE => AttributeType.DOUBLE
case AggregationFunction.MIN => attrType
case AggregationFunction.MAX => attrType
case AggregationFunction.CONCAT => AttributeType.STRING
case _ => throw new RuntimeException("Unknown aggregation function: " + this.aggFunction)
}
new Attribute(resultAttribute, resultAttrType)
}

@JsonIgnore
def getAggFunc(attrType: AttributeType): DistributedAggregation[Object] = {
val aggFunc = aggFunction match {
case AggregationFunction.AVERAGE => averageAgg()
case AggregationFunction.COUNT => countAgg()
case AggregationFunction.MAX => maxAgg(attrType)
case AggregationFunction.MIN => minAgg(attrType)
case AggregationFunction.SUM => sumAgg(attrType)
case AggregationFunction.CONCAT => concatAgg()
case AggregationFunction.AVERAGE => averageAgg()
case AggregationFunction.COUNT => countAgg(countAllRows = false)
case AggregationFunction.COUNT_STAR => countAgg(countAllRows = true)
case AggregationFunction.MAX => maxAgg(attrType)
case AggregationFunction.MIN => minAgg(attrType)
case AggregationFunction.SUM => sumAgg(attrType)
case AggregationFunction.CONCAT => concatAgg()
case _ =>
throw new UnsupportedOperationException("Unknown aggregation function: " + aggFunction)
}
Expand All @@ -106,8 +124,10 @@ class AggregationOperation {
@JsonIgnore
def getFinal: AggregationOperation = {
val newAggFunc = aggFunction match {
case AggregationFunction.COUNT => AggregationFunction.SUM
case a: AggregationFunction => a
// Both COUNT variants emit partial counts locally; the global stage sums them.
case AggregationFunction.COUNT => AggregationFunction.SUM
case AggregationFunction.COUNT_STAR => AggregationFunction.SUM
case a: AggregationFunction => a
}
val res = new AggregationOperation()
res.aggFunction = newAggFunc
Expand Down Expand Up @@ -138,16 +158,12 @@ class AggregationOperation {
)
}

private def countAgg(): DistributedAggregation[Integer] = {
private def countAgg(countAllRows: Boolean): DistributedAggregation[Integer] = {
// COUNT(*) counts every row; COUNT(column) counts only rows with a non-null attribute.
new DistributedAggregation[Integer](
() => 0,
(partial, tuple) => {
val inc =
if (attribute == null) 1
else if (tuple.getField(attribute) != null) 1
else 0
partial + inc
},
(partial, tuple) =>
partial + (if (countAllRows || tuple.getField(attribute) != null) 1 else 0),
(partial1, partial2) => partial1 + partial2,
partial => partial
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,13 @@ class AggregateOpDescSpec extends AnyFlatSpec with Matchers {
.getExternalOutputSchemas(Map(PortIdentity() -> input)) shouldBe
Map(PortIdentity() -> Schema().add("avg", AttributeType.DOUBLE))
}

it should "type a COUNT(*) result as INTEGER without looking up an input column" in {
  // The aggregation names the attribute "ghost", which the single-column input schema
  // ("v") does not contain. Schema propagation for COUNT(*) is expected to succeed
  // anyway and yield one INTEGER column named after the result attribute.
  val inputSchema = Schema().add("v", AttributeType.LONG)
  val countStar = aggOp(AggregationFunction.COUNT_STAR, "ghost", "row_count")

  val actualSchemas =
    descWith(List.empty, countStar).getExternalOutputSchemas(Map(PortIdentity() -> inputSchema))
  val expectedSchemas = Map(PortIdentity() -> Schema().add("row_count", AttributeType.INTEGER))

  actualSchemas shouldBe expectedSchemas
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ class AggregateOpSpec extends AnyFunSuite {
assert(attr.getType == AttributeType.INTEGER)
}

test("getAggregationAttribute maps COUNT(*) result to INTEGER even with a null input type") {
  // The operation is built with a blank attribute and asked for its result attribute
  // with a null input type; the result must still be an INTEGER column named
  // "row_count".
  val countStar = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count")
  val resultAttr = countStar.getAggregationAttribute(null)

  assert(resultAttr.getName == "row_count")
  assert(resultAttr.getType == AttributeType.INTEGER)
}

test("getAggregationAttribute maps CONCAT result type to STRING") {
val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag", "all_tags")
val attr = operation.getAggregationAttribute(AttributeType.INTEGER)
Expand Down Expand Up @@ -142,13 +152,33 @@ class AggregateOpSpec extends AnyFunSuite {
assert(math.abs(result - 4.0) < 1e-6)
}

test("COUNT aggregation with attribute == null counts all rows") {
test("COUNT(*) aggregation counts all rows regardless of nulls") {
// COUNT(*) hides the attribute in the UI, so it arrives with a blank attribute.
val schema = makeSchema("points" -> AttributeType.INTEGER)
val tuple1 = makeTuple(schema, 10)
val tuple2 = makeTuple(schema, null)
val tuple3 = makeTuple(schema, 20)

val operation = makeAggregationOp(AggregationFunction.COUNT, null, "row_count")
val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count")
val agg = operation.getAggFunc(AttributeType.INTEGER)

var partial = agg.init()
partial = agg.iterate(partial, tuple1)
partial = agg.iterate(partial, tuple2)
partial = agg.iterate(partial, tuple3)

val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
assert(result == 3)
}

test("COUNT(*) aggregation ignores any attribute value and still counts every row") {
// Even if an attribute leaks through, COUNT(*) must count all rows (incl. null cells).
val schema = makeSchema("points" -> AttributeType.INTEGER)
val tuple1 = makeTuple(schema, 10)
val tuple2 = makeTuple(schema, null)
val tuple3 = makeTuple(schema, 20)

val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "points", "row_count")
val agg = operation.getAggFunc(AttributeType.INTEGER)

var partial = agg.init()
Expand Down Expand Up @@ -447,6 +477,15 @@ class AggregateOpSpec extends AnyFunSuite {
assert(finalOp.resultAttribute == "price_count")
}

test("getFinal rewrites COUNT(*) into SUM over the intermediate result attribute") {
  // The final-stage operation derived from COUNT(*) should be a SUM that both reads
  // from and writes to the original result attribute ("row_count").
  val countStar = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count")
  val finalStage = countStar.getFinal

  assert(finalStage.aggFunction == AggregationFunction.SUM)
  assert(finalStage.attribute == "row_count")
  assert(finalStage.resultAttribute == "row_count")
}

test("getFinal keeps non-COUNT aggregation function and rewires attribute to resultAttribute") {
val operation = makeAggregationOp(AggregationFunction.SUM, "amount", "total_amount")
val finalOp = operation.getFinal
Expand Down Expand Up @@ -535,4 +574,33 @@ class AggregateOpSpec extends AnyFunSuite {
assert(totalRevenue == 350)
assert(rowCount == 3)
}

test("AggregateOpExec computes COUNT(*) over every row (including nulls) end-to-end") {
  // Three input rows over (region, revenue); the second row has a null revenue.
  // The aggregation names the attribute "ghost", which is not a column of this schema,
  // so the executor must count rows without looking that attribute up.
  val schema = makeSchema(
    "region" -> AttributeType.STRING,
    "revenue" -> AttributeType.INTEGER
  )
  val rows = List(
    makeTuple(schema, "west", 100),
    makeTuple(schema, "east", null),
    makeTuple(schema, "west", 50)
  )

  val countStar = makeAggregationOp(AggregationFunction.COUNT_STAR, "ghost", "row_count")
  val desc = new AggregateOpDesc()
  desc.aggregations = List(countStar)
  desc.groupByKeys = List() // no group-by keys: global aggregation

  val exec = new AggregateOpExec(objectMapper.writeValueAsString(desc))
  exec.open()
  rows.foreach(row => exec.processTuple(row, 0))

  // A global COUNT(*) emits exactly one tuple whose first field is the row count.
  val output = exec.onFinish(0).toList
  assert(output.size == 1)
  assert(output.head.getFields(0).asInstanceOf[Number].intValue() == 3)
}
}
8 changes: 5 additions & 3 deletions docs/reference/operators/data-cleaning/aggregate/aggregate.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ tags: [data-cleaning, aggregate]
| Property | Requirement | Type | Default | Description |
|----------|-------------|------|---------|-------------|
| Aggregations | ✓ | List<Aggregation> | - | Multiple aggregation functions (min: 1,<br>aggregations cannot be empty) |
| ↳ Aggregate Func | ✓ | sum, count, average, min, max, concat | - | Sum, count, average, min, max, or concat |
| ↳ Attribute | ✓ | String | - | Column to calculate average value |
| ↳ Result Attribute | ✓ | String | - | Column name of average result |
| ↳ Aggregate Func | ✓ | sum, count, count(*), average, min, max, concat | - | Sum, count, count(*), average, min, max, or concat |
| ↳ Attribute | ✓ (not required for `count(*)`) | String | - | Column to aggregate on. Required for every function except `count(*)`, which counts all rows; the field is hidden when `count(*)` is selected |
| ↳ Result Attribute | ✓ | String | - | Column name of the aggregation result |
| Group By Keys | | List | - | Group by columns |

> **Counting rows**: use `count(*)` to count every row (including rows with nulls) without selecting a column. Use `count` with a column to count only that column's non-null values.

### Output Ports

| Port | Mode |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ import { map, switchMap, take } from "rxjs/operators";

Quill.register("modules/cursors", QuillCursors);

// The Aggregate "count(*)" function: counts all rows and takes no attribute.
// This is the value compared against a row's `aggFunction` to decide whether the
// sibling `attribute` field is hidden and whether it is marked as required.
export const COUNT_STAR = "count(*)";

/**
* Property Editor uses JSON Schema to automatically generate the form from the JSON Schema of an operator.
* For example, the JSON Schema of Sentiment Analysis could be:
Expand Down Expand Up @@ -545,6 +548,17 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On
mappedField.type = "datasetversionselector";
}

// Aggregate: the attribute is required for every function except count(*), which counts all rows.
// For count(*) the attribute is irrelevant, so hide the field and drop the required marker.
// Both react to the sibling aggFunction within the same row.
if (this.currentOperatorSchema?.operatorType === "Aggregate" && mappedField.key === "attribute") {
mappedField.expressions = {
...mappedField.expressions,
"props.required": (field: FormlyFieldConfig) => field.parent?.model?.aggFunction !== COUNT_STAR,
hide: (field: FormlyFieldConfig) => field.parent?.model?.aggFunction === COUNT_STAR,
};
}

if (this.currentOperatorSchema?.operatorType === "FileScanOp" && mappedField.key === "outputFileName") {
mappedField.expressions = {
...mappedField.expressions,
Expand Down
Loading