diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDesc.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDesc.scala index 7e76b3ce7e0..ddc897a9e76 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDesc.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDesc.scala @@ -78,9 +78,14 @@ class AggregateOpDesc extends LogicalOp { val inputSchema = inputSchemas(operatorInfo.inputPorts.head.id) val outputSchema = Schema( groupByKeys.map(key => inputSchema.getAttribute(key)) ++ - localAggregations.map(agg => - agg.getAggregationAttribute(inputSchema.getAttribute(agg.attribute).getType) - ) + localAggregations.map { agg => + // COUNT(*) ignores the attribute entirely, so never look up an input + // column for it (a stale/leaked attribute may not exist in the schema). + val attrType = + if (agg.aggFunction == AggregationFunction.COUNT_STAR) null + else inputSchema.getAttribute(agg.attribute).getType + agg.getAggregationAttribute(attrType) + } ) Map(PortIdentity(internal = true) -> outputSchema) }) diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpExec.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpExec.scala index a61703b28cf..1c75999ae39 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpExec.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregateOpExec.scala @@ -47,9 +47,15 @@ class AggregateOpExec(descString: String) extends OperatorExecutor { // Initialize distributedAggregations if it's not yet initialized if (distributedAggregations == null) { - distributedAggregations = desc.aggregations.map(agg => - 
agg.getAggFunc(tuple.getSchema.getAttribute(agg.attribute).getType) - ) + distributedAggregations = desc.aggregations.map { agg => + // COUNT(*) ignores the attribute entirely, so never look up an input column for + // it (a stale/leaked attribute may not exist in the schema). Its result type + // does not depend on any input attribute. + val attrType = + if (agg.aggFunction == AggregationFunction.COUNT_STAR) null + else tuple.getSchema.getAttribute(agg.attribute).getType + agg.getAggFunc(attrType) + } } // Construct the group key diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationFunction.java b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationFunction.java index 253d6548524..96b2a09469a 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationFunction.java +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationFunction.java @@ -27,6 +27,8 @@ public enum AggregationFunction { COUNT("count"), + COUNT_STAR("count(*)"), + AVERAGE("average"), MIN("min"), diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationOperation.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationOperation.scala index 70105de9ef4..b84a46384b9 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationOperation.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aggregate/AggregationOperation.scala @@ -55,7 +55,23 @@ case class AveragePartialObj(sum: Double, count: Double) extends Serializable {} } ] } - } + }, + "allOf": [ + { + "if": { + "properties": { + "aggFunction": { "const": "count(*)" } + } + }, + "then": {}, + "else": { + "required": ["attribute"], + "properties": { + "attribute": { "minLength": 1 } + } + } + } + ] 
} """) class AggregationOperation { @@ -64,8 +80,8 @@ class AggregationOperation { - @JsonPropertyDescription("sum, count, average, min, max, or concat") + @JsonPropertyDescription("sum, count, count(*), average, min, max, or concat") var aggFunction: AggregationFunction = _ - @JsonProperty(value = "attribute", required = true) - @JsonPropertyDescription("column to calculate average value") + @JsonProperty(value = "attribute") + @JsonPropertyDescription("column to aggregate on") @AutofillAttributeName var attribute: String = _ @@ -77,13 +93,14 @@ class AggregationOperation { @JsonIgnore def getAggregationAttribute(attrType: AttributeType): Attribute = { val resultAttrType = this.aggFunction match { - case AggregationFunction.SUM => attrType - case AggregationFunction.COUNT => AttributeType.INTEGER - case AggregationFunction.AVERAGE => AttributeType.DOUBLE - case AggregationFunction.MIN => attrType - case AggregationFunction.MAX => attrType - case AggregationFunction.CONCAT => AttributeType.STRING - case _ => throw new RuntimeException("Unknown aggregation function: " + this.aggFunction) + case AggregationFunction.SUM => attrType + case AggregationFunction.COUNT => AttributeType.INTEGER + case AggregationFunction.COUNT_STAR => AttributeType.INTEGER + case AggregationFunction.AVERAGE => AttributeType.DOUBLE + case AggregationFunction.MIN => attrType + case AggregationFunction.MAX => attrType + case AggregationFunction.CONCAT => AttributeType.STRING + case _ => throw new RuntimeException("Unknown aggregation function: " + this.aggFunction) } new Attribute(resultAttribute, resultAttrType) } @@ -91,12 +108,13 @@ class AggregationOperation { @JsonIgnore def getAggFunc(attrType: AttributeType): DistributedAggregation[Object] = { val aggFunc = aggFunction match { - case AggregationFunction.AVERAGE => averageAgg() - case AggregationFunction.COUNT => countAgg() - case AggregationFunction.MAX => maxAgg(attrType) - case AggregationFunction.MIN => minAgg(attrType) - case AggregationFunction.SUM => sumAgg(attrType) - case AggregationFunction.CONCAT => 
concatAgg() + case AggregationFunction.AVERAGE => averageAgg() + case AggregationFunction.COUNT => countAgg(countAllRows = false) + case AggregationFunction.COUNT_STAR => countAgg(countAllRows = true) + case AggregationFunction.MAX => maxAgg(attrType) + case AggregationFunction.MIN => minAgg(attrType) + case AggregationFunction.SUM => sumAgg(attrType) + case AggregationFunction.CONCAT => concatAgg() case _ => throw new UnsupportedOperationException("Unknown aggregation function: " + aggFunction) } @@ -106,8 +124,10 @@ class AggregationOperation { @JsonIgnore def getFinal: AggregationOperation = { val newAggFunc = aggFunction match { - case AggregationFunction.COUNT => AggregationFunction.SUM - case a: AggregationFunction => a + // Both COUNT variants emit partial counts locally; the global stage sums them. + case AggregationFunction.COUNT => AggregationFunction.SUM + case AggregationFunction.COUNT_STAR => AggregationFunction.SUM + case a: AggregationFunction => a } val res = new AggregationOperation() res.aggFunction = newAggFunc @@ -138,16 +158,12 @@ class AggregationOperation { ) } - private def countAgg(): DistributedAggregation[Integer] = { + private def countAgg(countAllRows: Boolean): DistributedAggregation[Integer] = { + // COUNT(*) counts every row; COUNT(column) counts only rows with a non-null attribute. 
new DistributedAggregation[Integer]( () => 0, - (partial, tuple) => { - val inc = - if (attribute == null) 1 - else if (tuple.getField(attribute) != null) 1 - else 0 - partial + inc - }, + (partial, tuple) => + partial + (if (countAllRows || tuple.getField(attribute) != null) 1 else 0), (partial1, partial2) => partial1 + partial2, partial => partial ) diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDescSpec.scala index 681a1aa2be6..925c93acd1d 100644 --- a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDescSpec.scala +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpDescSpec.scala @@ -87,4 +87,13 @@ class AggregateOpDescSpec extends AnyFlatSpec with Matchers { .getExternalOutputSchemas(Map(PortIdentity() -> input)) shouldBe Map(PortIdentity() -> Schema().add("avg", AttributeType.DOUBLE)) } + + it should "type a COUNT(*) result as INTEGER without looking up an input column" in { + // COUNT(*) ignores its attribute; even a stale/leaked attribute that does not exist + // in the input schema must not be dereferenced during schema propagation. 
+ val input = Schema().add("v", AttributeType.LONG) + descWith(List.empty, aggOp(AggregationFunction.COUNT_STAR, "ghost", "row_count")) + .getExternalOutputSchemas(Map(PortIdentity() -> input)) shouldBe + Map(PortIdentity() -> Schema().add("row_count", AttributeType.INTEGER)) + } } diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpSpec.scala index cb7925ec41d..3c363ffc67e 100644 --- a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpSpec.scala +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aggregate/AggregateOpSpec.scala @@ -61,6 +61,16 @@ class AggregateOpSpec extends AnyFunSuite { assert(attr.getType == AttributeType.INTEGER) } + test("getAggregationAttribute maps COUNT(*) result to INTEGER even with a null input type") { + // COUNT(*) has no input column, so schema propagation passes a null attrType; + // it must still resolve to INTEGER without dereferencing it. + val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count") + val attr = operation.getAggregationAttribute(null) + + assert(attr.getName == "row_count") + assert(attr.getType == AttributeType.INTEGER) + } + test("getAggregationAttribute maps CONCAT result type to STRING") { val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag", "all_tags") val attr = operation.getAggregationAttribute(AttributeType.INTEGER) @@ -142,13 +152,33 @@ class AggregateOpSpec extends AnyFunSuite { assert(math.abs(result - 4.0) < 1e-6) } - test("COUNT aggregation with attribute == null counts all rows") { + test("COUNT(*) aggregation counts all rows regardless of nulls") { + // COUNT(*) hides the attribute in the UI, so it arrives with a blank attribute. 
val schema = makeSchema("points" -> AttributeType.INTEGER) val tuple1 = makeTuple(schema, 10) val tuple2 = makeTuple(schema, null) val tuple3 = makeTuple(schema, 20) - val operation = makeAggregationOp(AggregationFunction.COUNT, null, "row_count") + val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count") + val agg = operation.getAggFunc(AttributeType.INTEGER) + + var partial = agg.init() + partial = agg.iterate(partial, tuple1) + partial = agg.iterate(partial, tuple2) + partial = agg.iterate(partial, tuple3) + + val result = agg.finalAgg(partial).asInstanceOf[Number].intValue() + assert(result == 3) + } + + test("COUNT(*) aggregation ignores any attribute value and still counts every row") { + // Even if an attribute leaks through, COUNT(*) must count all rows (incl. null cells). + val schema = makeSchema("points" -> AttributeType.INTEGER) + val tuple1 = makeTuple(schema, 10) + val tuple2 = makeTuple(schema, null) + val tuple3 = makeTuple(schema, 20) + + val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "points", "row_count") val agg = operation.getAggFunc(AttributeType.INTEGER) var partial = agg.init() @@ -447,6 +477,15 @@ class AggregateOpSpec extends AnyFunSuite { assert(finalOp.resultAttribute == "price_count") } + test("getFinal rewrites COUNT(*) into SUM over the intermediate result attribute") { + val operation = makeAggregationOp(AggregationFunction.COUNT_STAR, "", "row_count") + val finalOp = operation.getFinal + + assert(finalOp.aggFunction == AggregationFunction.SUM) + assert(finalOp.attribute == "row_count") + assert(finalOp.resultAttribute == "row_count") + } + test("getFinal keeps non-COUNT aggregation function and rewires attribute to resultAttribute") { val operation = makeAggregationOp(AggregationFunction.SUM, "amount", "total_amount") val finalOp = operation.getFinal @@ -535,4 +574,33 @@ class AggregateOpSpec extends AnyFunSuite { assert(totalRevenue == 350) assert(rowCount == 3) } + + 
test("AggregateOpExec computes COUNT(*) over every row (including nulls) end-to-end") { + // region (ignored), revenue (one null). COUNT(*) ignores its attribute, so even a + // stale attribute that is absent from the schema must not be looked up; the + // executor still counts all 3 rows. + val schema = makeSchema( + "region" -> AttributeType.STRING, + "revenue" -> AttributeType.INTEGER + ) + + val tuple1 = makeTuple(schema, "west", 100) + val tuple2 = makeTuple(schema, "east", null) + val tuple3 = makeTuple(schema, "west", 50) + + val desc = new AggregateOpDesc() + desc.aggregations = + List(makeAggregationOp(AggregationFunction.COUNT_STAR, "ghost", "row_count")) + desc.groupByKeys = List() // global aggregation + + val exec = new AggregateOpExec(objectMapper.writeValueAsString(desc)) + exec.open() + exec.processTuple(tuple1, 0) + exec.processTuple(tuple2, 0) + exec.processTuple(tuple3, 0) + + val results = exec.onFinish(0).toList + assert(results.size == 1) + assert(results.head.getFields(0).asInstanceOf[Number].intValue() == 3) + } } diff --git a/docs/reference/operators/data-cleaning/aggregate/aggregate.md b/docs/reference/operators/data-cleaning/aggregate/aggregate.md index e2b3f0e9979..80012c1d84b 100644 --- a/docs/reference/operators/data-cleaning/aggregate/aggregate.md +++ b/docs/reference/operators/data-cleaning/aggregate/aggregate.md @@ -13,11 +13,13 @@ tags: [data-cleaning, aggregate] | Property | Requirement | Type | Default | Description | |----------|-------------|------|---------|-------------| | Aggregations | ✓ | List | - | Multiple aggregation functions (min: 1,
aggregations cannot be empty) | -| ↳ Aggregate Func | ✓ | sum, count, average, min, max, concat | - | Sum, count, average, min, max, or concat | -| ↳ Attribute | ✓ | String | - | Column to calculate average value | -| ↳ Result Attribute | ✓ | String | - | Column name of average result | +| ↳ Aggregate Func | ✓ | sum, count, count(*), average, min, max, concat | - | Sum, count, count(*), average, min, max, or concat | +| ↳ Attribute | ✓ (hidden for `count(*)`) | String | - | Column to aggregate on. Required for every function except `count(*)`, which counts all rows and hides this field | +| ↳ Result Attribute | ✓ | String | - | Column name of the aggregation result | | Group By Keys | | List | - | Group by columns | +> **Counting rows**: use `count(*)` to count every row (including rows with nulls) without selecting a column. Use `count` with a column to count only that column's non-null values. + ### Output Ports | Port | Mode | diff --git a/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts b/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts index 2512fecdacb..8590142ad23 100644 --- a/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts +++ b/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts @@ -77,6 +77,9 @@ import { map, switchMap, take } from "rxjs/operators"; Quill.register("modules/cursors", QuillCursors); +// The Aggregate "count(*)" function: counts all rows and takes no attribute. +export const COUNT_STAR = "count(*)"; + /** * Property Editor uses JSON Schema to automatically generate the form from the JSON Schema of an operator. 
* For example, the JSON Schema of Sentiment Analysis could be: @@ -545,6 +548,17 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On mappedField.type = "datasetversionselector"; } + // Aggregate: the attribute is required for every function except count(*), which counts all rows. + // For count(*) the attribute is irrelevant, so hide the field and drop the required marker. + // Both react to the sibling aggFunction within the same row. + if (this.currentOperatorSchema?.operatorType === "Aggregate" && mappedField.key === "attribute") { + mappedField.expressions = { + ...mappedField.expressions, + "props.required": (field: FormlyFieldConfig) => field.parent?.model?.aggFunction !== COUNT_STAR, + hide: (field: FormlyFieldConfig) => field.parent?.model?.aggFunction === COUNT_STAR, + }; + } + if (this.currentOperatorSchema?.operatorType === "FileScanOp" && mappedField.key === "outputFileName") { mappedField.expressions = { ...mappedField.expressions,