Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions env/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ cc_library(
"//runtime:runtime_builder_factory",
"//runtime:runtime_options",
"//runtime:standard_functions",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:protobuf",
],
)
Expand Down Expand Up @@ -236,10 +239,14 @@ cc_test(
"//common:source",
"//common:value",
"//compiler",
"//extensions:math_ext",
"//internal:status_macros",
"//internal:testing",
"//internal:testing_descriptor_pool",
"//runtime",
"//runtime:activation",
"//runtime:runtime_builder",
"//runtime:runtime_options",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_protobuf//:protobuf",
Expand Down
4 changes: 3 additions & 1 deletion env/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ absl::Status Config::AddExtensionConfig(std::string name, int version) {
}
return absl::AlreadyExistsError(absl::StrCat(
"Extension '", name, "' version ", extension_config.version,
" is already included. Cannot also include version ", version));
" is already included. Cannot also include version ",
version == ExtensionConfig::kLatest ? "'latest'"
: absl::StrCat(version)));
}
}
extension_configs_.push_back(
Expand Down
12 changes: 12 additions & 0 deletions env/env_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
#include <utility>
#include <vector>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "env/config.h"
#include "internal/status_macros.h"
#include "runtime/runtime.h"
Expand All @@ -29,6 +32,15 @@

namespace cel {

void EnvRuntime::RegisterExtensionFunctions(
absl::string_view name, absl::string_view alias, int version,
absl::AnyInvocable<absl::Status(RuntimeBuilder&, const RuntimeOptions&)
const>
function_registration_callback) {
extension_registry_.AddFunctionRegistration(
name, alias, version, std::move(function_registration_callback));
}

absl::StatusOr<RuntimeBuilder> EnvRuntime::CreateRuntimeBuilder() {
const std::vector<Config::ExtensionConfig>& extension_configs =
config_.GetExtensionConfigs();
Expand Down
12 changes: 12 additions & 0 deletions env/env_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
#include <memory>
#include <utility>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "env/config.h"
#include "env/internal/runtime_ext_registry.h"
#include "runtime/runtime.h"
Expand All @@ -41,6 +44,15 @@ namespace cel {
// compilation. This ensures consistency between compilation and runtime.
class EnvRuntime {
public:
// Registers a function registration callback for an extension. The callback
// is invoked when a runtime is created, if the corresponding functions are
// enabled in the runtime config.
void RegisterExtensionFunctions(
absl::string_view name, absl::string_view alias, int version,
absl::AnyInvocable<absl::Status(RuntimeBuilder&, const RuntimeOptions&)
const>
function_registration_callback);

void SetDescriptorPool(
std::shared_ptr<const google::protobuf::DescriptorPool> descriptor_pool) {
descriptor_pool_ = std::move(descriptor_pool);
Expand Down
40 changes: 40 additions & 0 deletions env/env_runtime_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@
#include "env/env_std_extensions.h"
#include "env/env_yaml.h"
#include "env/runtime_std_extensions.h"
#include "extensions/math_ext.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/activation.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace cel {
Expand Down Expand Up @@ -156,5 +160,41 @@ std::vector<TestCase> GetEnvRuntimeTestCases() {
INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest,
ValuesIn(GetEnvRuntimeTestCases()));

TEST(EnvRuntimeTest, RegisterExtensionFunctions) {
auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool();
Config config;
ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk());

Env env;
env.SetDescriptorPool(descriptor_pool);
RegisterStandardExtensions(env);
env.SetConfig(config);
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, env.NewCompiler());
ASSERT_OK_AND_ASSIGN(ValidationResult result,
compiler->Compile("math.sqrt(4) == 2.0"));
EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError();
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast, result.ReleaseAst());

EnvRuntime env_runtime;
env_runtime.SetDescriptorPool(descriptor_pool);
env_runtime.RegisterExtensionFunctions(
"cel.lib.math", "math", 2,
[](cel::RuntimeBuilder& runtime_builder,
const cel::RuntimeOptions& opts) -> absl::Status {
return cel::extensions::RegisterMathExtensionFunctions(
runtime_builder.function_registry(), opts, 2);
});
env_runtime.SetConfig(config);
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
env_runtime.NewRuntime());
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
runtime->CreateProgram(std::move(ast)));
ASSERT_NE(program, nullptr);

google::protobuf::Arena arena;
Activation activation;
ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
EXPECT_TRUE(value.GetBool());
}
} // namespace
} // namespace cel