diff --git a/env/BUILD b/env/BUILD index 8d477cc1f..55297b190 100644 --- a/env/BUILD +++ b/env/BUILD @@ -81,7 +81,10 @@ cc_library( "//runtime:runtime_builder_factory", "//runtime:runtime_options", "//runtime:standard_functions", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) @@ -236,10 +239,14 @@ cc_test( "//common:source", "//common:value", "//compiler", + "//extensions:math_ext", + "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", diff --git a/env/config.cc b/env/config.cc index ccb4de34c..1be9d7696 100644 --- a/env/config.cc +++ b/env/config.cc @@ -60,7 +60,9 @@ absl::Status Config::AddExtensionConfig(std::string name, int version) { } return absl::AlreadyExistsError(absl::StrCat( "Extension '", name, "' version ", extension_config.version, - " is already included. Cannot also include version ", version)); + " is already included. Cannot also include version ", + version == ExtensionConfig::kLatest ? "'latest'" + : absl::StrCat(version))); } } extension_configs_.push_back( diff --git a/env/env_runtime.cc b/env/env_runtime.cc index 09bbcde04..33e0747cc 100644 --- a/env/env_runtime.cc +++ b/env/env_runtime.cc @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "internal/status_macros.h" #include "runtime/runtime.h" @@ -29,6 +32,15 @@ namespace cel { +void EnvRuntime::RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback) { + extension_registry_.AddFunctionRegistration( + name, alias, version, std::move(function_registration_callback)); +} + absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { const std::vector& extension_configs = config_.GetExtensionConfigs(); diff --git a/env/env_runtime.h b/env/env_runtime.h index ff62ec1d4..63473c295 100644 --- a/env/env_runtime.h +++ b/env/env_runtime.h @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "env/internal/runtime_ext_registry.h" #include "runtime/runtime.h" @@ -41,6 +44,15 @@ namespace cel { // compilation. This ensures consistency between compilation and runtime. class EnvRuntime { public: + // Registers a function registration callback for an extension. The callback + // is invoked when a runtime is created, if the corresponding functions are + // enabled in the runtime config. + void RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback); + void SetDescriptorPool( std::shared_ptr descriptor_pool) { descriptor_pool_ = std::move(descriptor_pool); diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc index 1c4205224..9b60cf591 100644 --- a/env/env_runtime_test.cc +++ b/env/env_runtime_test.cc @@ -31,10 +31,14 @@ #include "env/env_std_extensions.h" #include "env/env_yaml.h" #include "env/runtime_std_extensions.h" +#include "extensions/math_ext.h" +#include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "runtime/activation.h" #include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace cel { @@ -156,5 +160,41 @@ std::vector GetEnvRuntimeTestCases() { INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, ValuesIn(GetEnvRuntimeTestCases())); +TEST(EnvRuntimeTest, RegisterExtensionFunctions) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("math.sqrt(4) == 2.0")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + env_runtime.RegisterExtensionFunctions( + "cel.lib.math", "math", 2, + [](cel::RuntimeBuilder& runtime_builder, + const cel::RuntimeOptions& opts) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), opts, 2); + }); + env_runtime.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()); +} } // namespace } // namespace cel