Skip to content

Commit 0392908

Browse files
l46kokcopybara-github
authored andcommitted
Allow expected result type to be configured per AST validator instance
PiperOrigin-RevId: 831021405
1 parent 9083e06 commit 0392908

File tree

9 files changed

+68
-20
lines changed

9 files changed

+68
-20
lines changed

validator/src/main/java/dev/cel/validator/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ java_library(
7272
"//common:compiler_common",
7373
"//common:source_location",
7474
"//common/navigation",
75+
"//common/types",
76+
"//common/types:type_providers",
7577
"@maven//:com_google_guava_guava",
7678
],
7779
)

validator/src/main/java/dev/cel/validator/CelAstValidator.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,22 @@
2121
import dev.cel.common.CelSource;
2222
import dev.cel.common.CelSourceLocation;
2323
import dev.cel.common.navigation.CelNavigableAst;
24+
import dev.cel.common.types.CelType;
25+
import dev.cel.common.types.SimpleType;
2426
import java.util.Optional;
2527

2628
/** Public interface for performing a single, custom validation on an AST. */
2729
public interface CelAstValidator {
2830

2931
void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory);
3032

33+
/** Enforces a specific expected result type during validation, if set. */
34+
default CelType expectedResultType() {
35+
return SimpleType.DYN;
36+
}
37+
3138
/** Factory for populating issues while performing AST validation. */
32-
public final class IssuesFactory {
39+
final class IssuesFactory {
3340
private final ImmutableList.Builder<CelIssue> issuesBuilder;
3441
private final CelNavigableAst navigableAst;
3542

validator/src/main/java/dev/cel/validator/CelValidatorImpl.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@ public CelValidationResult validate(CelAbstractSyntaxTree ast) {
4848
ImmutableList.Builder<CelIssue> issueBuilder = ImmutableList.builder();
4949

5050
for (CelAstValidator validator : astValidators) {
51+
Cel celEnv = this.cel.toCelBuilder().setResultType(validator.expectedResultType()).build();
52+
5153
CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast);
5254
IssuesFactory issuesFactory = new IssuesFactory(navigableAst);
53-
validator.validate(navigableAst, cel, issuesFactory);
55+
validator.validate(navigableAst, celEnv, issuesFactory);
5456
issueBuilder.addAll(issuesFactory.getIssues());
5557
}
5658

validator/src/main/java/dev/cel/validator/validators/BUILD.bazel

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ java_library(
1717
],
1818
deps = [
1919
":literal_validator",
20-
"@maven//:com_google_protobuf_protobuf_java",
20+
"//common/types",
21+
"//common/types:type_providers",
2122
],
2223
)
2324

@@ -30,7 +31,8 @@ java_library(
3031
],
3132
deps = [
3233
":literal_validator",
33-
"@maven//:com_google_protobuf_protobuf_java",
34+
"//common/types",
35+
"//common/types:type_providers",
3436
],
3537
)
3638

@@ -116,6 +118,7 @@ java_library(
116118
"//common/ast",
117119
"//common/ast:expr_factory",
118120
"//common/navigation",
121+
"//common/types:type_providers",
119122
"//runtime",
120123
"//validator:ast_validator",
121124
"@maven//:com_google_errorprone_error_prone_annotations",

validator/src/main/java/dev/cel/validator/validators/DurationLiteralValidator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
package dev.cel.validator.validators;
1616

17+
import dev.cel.common.types.CelType;
18+
import dev.cel.common.types.SimpleType;
1719
import java.time.Duration;
1820

1921
/** DurationLiteralValidator ensures that duration literal arguments are valid. */
2022
public final class DurationLiteralValidator extends LiteralValidator {
2123
public static final DurationLiteralValidator INSTANCE =
22-
new DurationLiteralValidator("duration", Duration.class);
24+
new DurationLiteralValidator("duration", Duration.class, SimpleType.DURATION);
2325

24-
private DurationLiteralValidator(String functionName, Class<?> expectedResultType) {
25-
super(functionName, expectedResultType);
26+
private DurationLiteralValidator(
27+
String functionName, Class<?> expectedJavaType, CelType expectedResultType) {
28+
super(functionName, expectedJavaType, expectedResultType);
2629
}
2730
}

validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ public final class HomogeneousLiteralValidator implements CelAstValidator {
3838
private final ImmutableSet<String> exemptFunctions;
3939

4040
/**
41-
* Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not for
42-
* functions in {@code exemptFunctions}.
41+
* Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not run
42+
* for functions in {@code exemptFunctions}.
4343
*/
4444
public static HomogeneousLiteralValidator newInstance(Iterable<String> exemptFunctions) {
4545
return new HomogeneousLiteralValidator(exemptFunctions);
4646
}
4747

4848
/**
49-
* Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not for
50-
* functions in {@code exemptFunctions}.
49+
* Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not run
50+
* for functions in {@code exemptFunctions}.
5151
*/
5252
public static HomogeneousLiteralValidator newInstance(String... exemptFunctions) {
5353
return newInstance(Arrays.asList(exemptFunctions));

validator/src/main/java/dev/cel/validator/validators/LiteralValidator.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import dev.cel.common.ast.CelExprFactory;
2525
import dev.cel.common.navigation.CelNavigableAst;
2626
import dev.cel.common.navigation.CelNavigableExpr;
27+
import dev.cel.common.types.CelType;
2728
import dev.cel.runtime.CelEvaluationException;
2829
import dev.cel.validator.CelAstValidator;
2930

@@ -34,13 +35,21 @@
3435
*/
3536
public abstract class LiteralValidator implements CelAstValidator {
3637
private final String functionName;
37-
private final Class<?> expectedResultType;
38+
private final Class<?> expectedJavaType;
39+
private final CelType expectedResultType;
3840

39-
protected LiteralValidator(String functionName, Class<?> expectedResultType) {
41+
protected LiteralValidator(
42+
String functionName, Class<?> expectedJavaType, CelType expectedResultType) {
4043
this.functionName = functionName;
44+
this.expectedJavaType = expectedJavaType;
4145
this.expectedResultType = expectedResultType;
4246
}
4347

48+
@Override
49+
public CelType expectedResultType() {
50+
return expectedResultType;
51+
}
52+
4453
@Override
4554
public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) {
4655
CelExprFactory exprFactory = CelExprFactory.newInstance();
@@ -61,7 +70,7 @@ public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issues
6170
CelExpr callExpr =
6271
exprFactory.newGlobalCall(functionName, exprFactory.newConstant(expr.constant()));
6372
try {
64-
evaluateExpr(cel, callExpr, expectedResultType);
73+
evaluateExpr(cel, callExpr, expectedJavaType);
6574
} catch (Exception e) {
6675
issuesFactory.addError(
6776
expr.id(),
@@ -72,18 +81,18 @@ public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issues
7281
}
7382

7483
@CanIgnoreReturnValue
75-
private static Object evaluateExpr(Cel cel, CelExpr expr, Class<?> expectedResultType)
84+
private static Object evaluateExpr(Cel cel, CelExpr expr, Class<?> expectedJavaType)
7685
throws CelValidationException, CelEvaluationException {
7786
CelAbstractSyntaxTree ast =
7887
CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build());
7988
ast = cel.check(ast).getAst();
8089
Object result = cel.createProgram(ast).eval();
8190

82-
if (!expectedResultType.isInstance(result)) {
91+
if (!expectedJavaType.isInstance(result)) {
8392
throw new IllegalStateException(
8493
String.format(
8594
"Expected %s type but got %s instead",
86-
expectedResultType.getName(), result.getClass().getName()));
95+
expectedJavaType.getName(), result.getClass().getName()));
8796
}
8897
return result;
8998
}

validator/src/main/java/dev/cel/validator/validators/TimestampLiteralValidator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
package dev.cel.validator.validators;
1616

17+
import dev.cel.common.types.CelType;
18+
import dev.cel.common.types.SimpleType;
1719
import java.time.Instant;
1820

1921
/** TimestampLiteralValidator ensures that timestamp literal arguments are valid. */
2022
public final class TimestampLiteralValidator extends LiteralValidator {
2123
public static final TimestampLiteralValidator INSTANCE =
22-
new TimestampLiteralValidator("timestamp", Instant.class);
24+
new TimestampLiteralValidator("timestamp", Instant.class, SimpleType.TIMESTAMP);
2325

24-
private TimestampLiteralValidator(String functionName, Class<?> expectedResultType) {
25-
super(functionName, expectedResultType);
26+
private TimestampLiteralValidator(
27+
String functionName, Class<?> expectedJavaType, CelType expectedResultType) {
28+
super(functionName, expectedJavaType, expectedResultType);
2629
}
2730
}

validator/src/test/java/dev/cel/validator/validators/TimestampLiteralValidatorTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,23 @@ public void parentIsNotCallExpr_doesNotThrow(String source) throws Exception {
200200
assertThat(result.hasError()).isFalse();
201201
assertThat(result.getAllIssues()).isEmpty();
202202
}
203+
204+
@Test
205+
public void env_withSetResultType_success() throws Exception {
206+
Cel cel =
207+
CelFactory.standardCelBuilder()
208+
.setOptions(CelOptions.current().enableTimestampEpoch(true).build())
209+
.setResultType(SimpleType.BOOL)
210+
.build();
211+
CelValidator validator =
212+
CelValidatorFactory.standardCelValidatorBuilder(cel)
213+
.addAstValidators(TimestampLiteralValidator.INSTANCE)
214+
.build();
215+
CelAbstractSyntaxTree ast = cel.compile("timestamp(123) == timestamp(123)").getAst();
216+
217+
CelValidationResult result = validator.validate(ast);
218+
219+
assertThat(result.hasError()).isFalse();
220+
assertThat(result.getAllIssues()).isEmpty();
221+
}
203222
}

0 commit comments

Comments
 (0)