Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bundle/src/main/java/dev/cel/bundle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ java_library(
":required_fields_checker",
"//:auto_value",
"//bundle:cel",
"//checker:proto_type_mask",
"//checker:standard_decl",
"//common:compiler_common",
"//common:container",
Expand Down
23 changes: 23 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelEnvironment.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import dev.cel.checker.CelStandardDeclarations;
import dev.cel.checker.CelStandardDeclarations.StandardFunction;
import dev.cel.checker.CelStandardDeclarations.StandardOverload;
import dev.cel.checker.ProtoTypeMask;
import dev.cel.common.CelContainer;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
Expand Down Expand Up @@ -134,6 +135,9 @@ public abstract class CelEnvironment {
/** Limits to set in the environment. */
public abstract ImmutableSet<Limit> limits();

/** Context variable to enable in the environment. */
public abstract Optional<ContextVariable> contextVariable();

/** Builder for {@link CelEnvironment}. */
@AutoValue.Builder
public abstract static class Builder {
Expand Down Expand Up @@ -199,6 +203,8 @@ public Builder setLimits(Limit... limits) {

public abstract Builder setLimits(ImmutableSet<Limit> limits);

public abstract Builder setContextVariable(ContextVariable contextVariable);

abstract CelEnvironment autoBuild();

@CheckReturnValue
Expand Down Expand Up @@ -258,6 +264,12 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions)

applyStandardLibrarySubset(compilerBuilder);

contextVariable()
.ifPresent(
cv ->
compilerBuilder.addProtoTypeMasks(
ProtoTypeMask.ofAllFields(cv.typeName()).withFieldsAsVariableDeclarations()));

return compilerBuilder.build();
} catch (RuntimeException e) {
throw new CelEnvironmentException(e.getMessage(), e);
Expand Down Expand Up @@ -406,6 +418,17 @@ private static CanonicalCelExtension getExtensionOrThrow(String extensionName) {
return extension;
}

/** Represents a context variable declaration. */
@AutoValue
public abstract static class ContextVariable {
/** Fully qualified type name of the context variable. */
public abstract String typeName();

public static ContextVariable create(String typeName) {
return new AutoValue_CelEnvironment_ContextVariable(typeName);
}
}

/** Represents a policy variable declaration. */
@AutoValue
public abstract static class VariableDecl {
Expand Down
34 changes: 34 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import dev.cel.bundle.CelEnvironment.Alias;
import dev.cel.bundle.CelEnvironment.ContextVariable;
import dev.cel.bundle.CelEnvironment.ExtensionConfig;
import dev.cel.bundle.CelEnvironment.FunctionDecl;
import dev.cel.bundle.CelEnvironment.LibrarySubset;
Expand Down Expand Up @@ -320,6 +321,36 @@ private ImmutableSet<String> parseAbbreviations(ParserContext<Node> ctx, Node no
return builder.build();
}

private ContextVariable parseContextVariable(ParserContext<Node> ctx, Node node) {
long id = ctx.collectMetadata(node);
if (!assertYamlType(ctx, id, node, YamlNodeType.MAP)) {
return ContextVariable.create("");
}

MappingNode mapNode = (MappingNode) node;
String typeName = "";
for (NodeTuple nodeTuple : mapNode.getValue()) {
Node keyNode = nodeTuple.getKeyNode();
long keyId = ctx.collectMetadata(keyNode);
Node valueNode = nodeTuple.getValueNode();
String keyName = ((ScalarNode) keyNode).getValue();
switch (keyName) {
case "type_name":
typeName = newString(ctx, valueNode);
break;
default:
ctx.reportError(keyId, String.format("Unsupported context_variable tag: %s", keyName));
break;
}
}

if (typeName.isEmpty()) {
ctx.reportError(id, "Missing required attribute(s): type_name");
}

return ContextVariable.create(typeName);
}

private ImmutableSet<VariableDecl> parseVariables(ParserContext<Node> ctx, Node node) {
long valueId = ctx.collectMetadata(node);
ImmutableSet.Builder<VariableDecl> variableSetBuilder = ImmutableSet.builder();
Expand Down Expand Up @@ -900,6 +931,9 @@ private CelEnvironment.Builder parseConfig(ParserContext<Node> ctx, Node node) {
case "limits":
builder.setLimits(parseLimits(ctx, valueNode));
break;
case "context_variable":
builder.setContextVariable(parseContextVariable(ctx, valueNode));
break;
default:
ctx.reportError(id, "Unknown config tag: " + fieldName);
// continue handling the rest of the nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ java_library(
deps = [
":cel_expression_source",
":default_result_matcher",
":registry_utils",
":result_matcher",
"//:auto_value",
"//bundle:cel",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,21 @@ public Optional<CelDescriptors> celDescriptors() {
return Optional.empty();
}

/** Returns a unified set of {@link CelDescriptors} combined from all descriptor sources. */
@Memoized
public Optional<TypeRegistry> typeRegistry() {
public Optional<CelDescriptors> mergedDescriptors() {
if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) {
return Optional.empty();
}
TypeRegistry.Builder builder = TypeRegistry.newBuilder();
if (!fileTypes().isEmpty()) {
builder.add(
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes())
.messageTypeDescriptors());
}
if (celDescriptors().isPresent()) {
builder.add(celDescriptors().get().messageTypeDescriptors());
}
return Optional.of(builder.build());
ImmutableSet.Builder<FileDescriptor> allFiles =
ImmutableSet.<FileDescriptor>builder().addAll(fileTypes());
celDescriptors().ifPresent(d -> allFiles.addAll(d.fileDescriptors()));
return Optional.of(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(allFiles.build()));
}

@Memoized
public Optional<TypeRegistry> typeRegistry() {
return mergedDescriptors().map(RegistryUtils::getTypeRegistry);
}

public abstract Optional<ExtensionRegistry> extensionRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,33 @@ private static Object getEvaluationResultWithMessage(
}

private static Message unpackAny(Any any, CelTestContext celTestContext) throws IOException {
if (!celTestContext.fileDescriptorSetPath().isPresent()) {
throw new IllegalArgumentException(
"Proto descriptors are required for unpacking Any messages.");
TypeRegistry typeRegistry =
celTestContext
.typeRegistry()
.orElseThrow(
() ->
new IllegalArgumentException(
"Proto descriptors or type registry are required for unpacking Any"
+ " messages."));

Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(any.getTypeUrl());
if (descriptor == null) {
throw new IllegalArgumentException("Descriptor not found for type URL: " + any.getTypeUrl());
}
Descriptor descriptor =
RegistryUtils.getTypeRegistry(celTestContext.celDescriptors().get())
.getDescriptorForTypeUrl(any.getTypeUrl());

ExtensionRegistry extensionRegistry =
celTestContext
.extensionRegistry()
.orElseGet(
() ->
celTestContext
.mergedDescriptors()
.map(RegistryUtils::getExtensionRegistry)
.orElseGet(ExtensionRegistry::getEmptyRegistry));

return DynamicMessage.getDefaultInstance(descriptor)
.getParserForType()
.parseFrom(
any.getValue(),
RegistryUtils.getExtensionRegistry(celTestContext.celDescriptors().get()));
.parseFrom(any.getValue(), extensionRegistry);
}

private static Message getEvaluatedContextExpr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ public void runTest_missingProtoDescriptors_failure() throws Exception {

assertThat(thrown)
.hasMessageThat()
.contains("Proto descriptors are required for unpacking Any messages.");
.contains("Proto descriptors or type registry are required for unpacking Any messages");
}

@Test
Expand Down
Loading