diff --git a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel index 742f718f1..716442849 100644 --- a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel @@ -104,6 +104,7 @@ java_library( ":required_fields_checker", "//:auto_value", "//bundle:cel", + "//checker:proto_type_mask", "//checker:standard_decl", "//common:compiler_common", "//common:container", diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java index b85f16cb1..ccbaef61b 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java @@ -30,6 +30,7 @@ import dev.cel.checker.CelStandardDeclarations; import dev.cel.checker.CelStandardDeclarations.StandardFunction; import dev.cel.checker.CelStandardDeclarations.StandardOverload; +import dev.cel.checker.ProtoTypeMask; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -134,6 +135,9 @@ public abstract class CelEnvironment { /** Limits to set in the environment. */ public abstract ImmutableSet limits(); + /** Context variable to enable in the environment. */ + public abstract Optional contextVariable(); + /** Builder for {@link CelEnvironment}. */ @AutoValue.Builder public abstract static class Builder { @@ -199,6 +203,8 @@ public Builder setLimits(Limit... limits) { public abstract Builder setLimits(ImmutableSet limits); + public abstract Builder setContextVariable(ContextVariable contextVariable); + abstract CelEnvironment autoBuild(); @CheckReturnValue @@ -258,6 +264,12 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions) applyStandardLibrarySubset(compilerBuilder); + contextVariable() + .ifPresent( + cv -> + compilerBuilder.addProtoTypeMasks( + ProtoTypeMask.ofAllFields(cv.typeName()).withFieldsAsVariableDeclarations())); + return compilerBuilder.build(); } catch (RuntimeException e) { throw new CelEnvironmentException(e.getMessage(), e); @@ -406,6 +418,17 @@ private static CanonicalCelExtension getExtensionOrThrow(String extensionName) { return extension; } + /** Represents a context variable declaration. */ + @AutoValue + public abstract static class ContextVariable { + /** Fully qualified type name of the context variable. */ + public abstract String typeName(); + + public static ContextVariable create(String typeName) { + return new AutoValue_CelEnvironment_ContextVariable(typeName); + } + } + /** Represents a policy variable declaration. */ @AutoValue public abstract static class VariableDecl { diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java index f129d9f5d..14f1c93d8 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import dev.cel.bundle.CelEnvironment.Alias; +import dev.cel.bundle.CelEnvironment.ContextVariable; import dev.cel.bundle.CelEnvironment.ExtensionConfig; import dev.cel.bundle.CelEnvironment.FunctionDecl; import dev.cel.bundle.CelEnvironment.LibrarySubset; @@ -320,6 +321,36 @@ private ImmutableSet parseAbbreviations(ParserContext ctx, Node no return builder.build(); } + private ContextVariable parseContextVariable(ParserContext ctx, Node node) { + long id = ctx.collectMetadata(node); + if (!assertYamlType(ctx, id, node, YamlNodeType.MAP)) { + return ContextVariable.create(""); + } + + MappingNode mapNode = (MappingNode) node; + String typeName = ""; + for (NodeTuple nodeTuple : mapNode.getValue()) { + Node keyNode = nodeTuple.getKeyNode(); + long keyId = ctx.collectMetadata(keyNode); + Node valueNode = nodeTuple.getValueNode(); + String keyName = ((ScalarNode) keyNode).getValue(); + switch (keyName) { + case "type_name": + typeName = newString(ctx, valueNode); + break; + default: + ctx.reportError(keyId, String.format("Unsupported context_variable tag: %s", keyName)); + break; + } + } + + if (typeName.isEmpty()) { + ctx.reportError(id, "Missing required attribute(s): type_name"); + } + + return ContextVariable.create(typeName); + } + private ImmutableSet parseVariables(ParserContext ctx, Node node) { long valueId = ctx.collectMetadata(node); ImmutableSet.Builder variableSetBuilder = ImmutableSet.builder(); @@ -900,6 +931,9 @@ private CelEnvironment.Builder parseConfig(ParserContext ctx, Node node) { case "limits": builder.setLimits(parseLimits(ctx, valueNode)); break; + case "context_variable": + builder.setContextVariable(parseContextVariable(ctx, valueNode)); + break; default: ctx.reportError(id, "Unknown config tag: " + fieldName); // continue handling the rest of the nodes diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index d0fed9bea..677884a8a 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -161,6 +161,7 @@ java_library( deps = [ ":cel_expression_source", ":default_result_matcher", + ":registry_utils", ":result_matcher", "//:auto_value", "//bundle:cel", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java index 1be0bab25..6ef988a44 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java @@ -140,21 +140,21 @@ public Optional celDescriptors() { return Optional.empty(); } + /** Returns a unified set of {@link CelDescriptors} combined from all descriptor sources. */ @Memoized - public Optional typeRegistry() { + public Optional mergedDescriptors() { if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) { return Optional.empty(); } - TypeRegistry.Builder builder = TypeRegistry.newBuilder(); - if (!fileTypes().isEmpty()) { - builder.add( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes()) - .messageTypeDescriptors()); - } - if (celDescriptors().isPresent()) { - builder.add(celDescriptors().get().messageTypeDescriptors()); - } - return Optional.of(builder.build()); + ImmutableSet.Builder allFiles = + ImmutableSet.builder().addAll(fileTypes()); + celDescriptors().ifPresent(d -> allFiles.addAll(d.fileDescriptors())); + return Optional.of(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(allFiles.build())); + } + + @Memoized + public Optional typeRegistry() { + return mergedDescriptors().map(RegistryUtils::getTypeRegistry); } public abstract Optional extensionRegistry(); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index 69c365972..1d3e49fbe 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -360,18 +360,33 @@ private static Object getEvaluationResultWithMessage( } private static Message unpackAny(Any any, CelTestContext celTestContext) throws IOException { - if (!celTestContext.fileDescriptorSetPath().isPresent()) { - throw new IllegalArgumentException( - "Proto descriptors are required for unpacking Any messages."); + TypeRegistry typeRegistry = + celTestContext + .typeRegistry() + .orElseThrow( + () -> + new IllegalArgumentException( + "Proto descriptors or type registry are required for unpacking Any" + + " messages.")); + + Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(any.getTypeUrl()); + if (descriptor == null) { + throw new IllegalArgumentException("Descriptor not found for type URL: " + any.getTypeUrl()); } - Descriptor descriptor = - RegistryUtils.getTypeRegistry(celTestContext.celDescriptors().get()) - .getDescriptorForTypeUrl(any.getTypeUrl()); + + ExtensionRegistry extensionRegistry = + celTestContext + .extensionRegistry() + .orElseGet( + () -> + celTestContext + .mergedDescriptors() + .map(RegistryUtils::getExtensionRegistry) + .orElseGet(ExtensionRegistry::getEmptyRegistry)); + return DynamicMessage.getDefaultInstance(descriptor) .getParserForType() - .parseFrom( - any.getValue(), - RegistryUtils.getExtensionRegistry(celTestContext.celDescriptors().get())); + .parseFrom(any.getValue(), extensionRegistry); } private static Message getEvaluatedContextExpr( diff --git a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java index b83375b35..112ef1f82 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java @@ -262,7 +262,7 @@ public void runTest_missingProtoDescriptors_failure() throws Exception { assertThat(thrown) .hasMessageThat() - .contains("Proto descriptors are required for unpacking Any messages."); + .contains("Proto descriptors or type registry are required for unpacking Any messages"); } @Test