From 53bb24ae6d8498ac336219e4befbf43271aa91c3 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 8 Jan 2026 15:32:00 -0800 Subject: [PATCH] Checker: Update variable resolution with proceeding `.`. `[0].exists(x, .x == 2) // where x := 2 as a context variable` PiperOrigin-RevId: 853907987 --- checker/internal/type_check_env.cc | 8 +- checker/internal/type_check_env.h | 22 ++-- checker/internal/type_checker_impl.cc | 115 +++++++++++++++------ checker/internal/type_checker_impl_test.cc | 55 ++++++++-- 4 files changed, 140 insertions(+), 60 deletions(-) diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 91bfbaafa..d856a7230 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -16,7 +16,6 @@ #include #include -#include #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -134,7 +133,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); if (type.has_value()) { - return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type)); + return MakeVariableDecl(type->name(), TypeType(arena, *type)); } if (name.find('.') != name.npos) { @@ -185,7 +184,7 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( return absl::nullopt; } -const VariableDecl* absl_nullable VariableScope::LookupVariable( +const VariableDecl* absl_nullable VariableScope::LookupLocalVariable( absl::string_view name) const { const VariableScope* scope = this; while (scope != nullptr) { @@ -194,8 +193,7 @@ const VariableDecl* absl_nullable VariableScope::LookupVariable( } scope = scope->parent_; } - - return env_->LookupVariable(name); + return nullptr; } } // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 0b2ad31ed..f7f81f2a9 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -42,13 +42,11 @@ class TypeCheckEnv; // Helper class for managing nested scopes and the local variables they // implicitly declare. // -// Nested scopes have a lifetime dependency on any parent scopes and the -// parent Type environment. Nested scopes should generally be managed by -// unique_ptrs. +// Nested scopes have a lifetime dependency on any parent scopes and should +// generally be managed by unique_ptrs. class VariableScope { public: - explicit VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND) - : env_(&env), parent_(nullptr) {} + explicit VariableScope() : parent_(nullptr) {} VariableScope(const VariableScope&) = delete; VariableScope& operator=(const VariableScope&) = delete; @@ -61,18 +59,17 @@ class VariableScope { std::unique_ptr MakeNestedScope() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return absl::WrapUnique(new VariableScope(*env_, this)); + return absl::WrapUnique(new VariableScope(this)); } - const VariableDecl* absl_nullable LookupVariable( + const VariableDecl* absl_nullable LookupLocalVariable( absl::string_view name) const; private: - VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND, - const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND) - : env_(&env), parent_(parent) {} + explicit VariableScope( + const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND) + : parent_(parent) {} - const TypeCheckEnv* absl_nonnull env_; const VariableScope* absl_nullable parent_; absl::flat_hash_map variables_; }; @@ -190,9 +187,6 @@ class TypeCheckEnv { TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return TypeCheckEnv(this); } - VariableScope MakeVariableScope() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return VariableScope(*this); - } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_.get(); diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6d77bb2e2..55e68d1d2 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -26,6 +26,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -234,6 +235,11 @@ class ResolveVisitor : public AstVisitorBase { bool namespace_rewrite; }; + struct AttributeResolution { + const VariableDecl* decl; + bool requires_disambiguation; + }; + ResolveVisitor(absl::string_view container, NamespaceGenerator namespace_generator, const TypeCheckEnv& env, const Ast& ast, @@ -246,7 +252,7 @@ class ResolveVisitor : public AstVisitorBase { inference_context_(&inference_context), issues_(&issues), ast_(&ast), - root_scope_(env.MakeVariableScope()), + root_scope_(), arena_(arena), current_scope_(&root_scope_) {} @@ -294,7 +300,7 @@ class ResolveVisitor : public AstVisitorBase { return functions_; } - const absl::flat_hash_map& attributes() + const absl::flat_hash_map& attributes() const { return attributes_; } @@ -344,9 +350,13 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view function_name, int arg_count, bool is_receiver); - // Resolves the function call shape (i.e. the number of arguments and call - // style) for the given function call. - const VariableDecl* absl_nullable LookupIdentifier(absl::string_view name); + // Resolves a global identifier (i.e. declared in the CEL environment). + const VariableDecl* absl_nullable LookupGlobalIdentifier( + absl::string_view name); + + // Resolves a local identifier (i.e. a bind or comrprehension var). + const VariableDecl* absl_nullable LookupLocalIdentifier( + absl::string_view name); // Resolves the applicable function overloads for the given function call. // @@ -476,7 +486,7 @@ class ResolveVisitor : public AstVisitorBase { // References that were resolved and may require AST rewrites. absl::flat_hash_map functions_; - absl::flat_hash_map attributes_; + absl::flat_hash_map attributes_; absl::flat_hash_map struct_types_; absl::flat_hash_map types_; @@ -967,10 +977,20 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, types_[&expr] = resolution->result_type; } -const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier( +const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( absl::string_view name) { - if (const VariableDecl* decl = current_scope_->LookupVariable(name); - decl != nullptr) { + // Note: if we see a leading dot, this shouldn't resolve to a local variable, + // but we need to check whether we need to disambiguate against a global in + // the reference map. + if (absl::StartsWith(name, ".")) { + name = name.substr(1); + } + return current_scope_->LookupLocalVariable(name); +} + +const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( + absl::string_view name) { + if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } absl::StatusOr> constant = @@ -996,22 +1016,34 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier( void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { + // Local variables (comprehension, bind) are simple identifiers so we can + // skip generating the different namespace-qualified candidates. + const VariableDecl* local_decl = LookupLocalIdentifier(name); + + if (local_decl != nullptr && !absl::StartsWith(name, ".")) { + attributes_[&expr] = {local_decl, false}; + types_[&expr] = + inference_context_->InstantiateTypeParams(local_decl->type()); + return; + } + const VariableDecl* decl = nullptr; namespace_generator_.GenerateCandidates( name, [&decl, this](absl::string_view candidate) { - decl = LookupIdentifier(candidate); + decl = LookupGlobalIdentifier(candidate); // continue searching. return decl == nullptr; }); - if (decl == nullptr) { - ReportMissingReference(expr, name); - types_[&expr] = ErrorType(); + if (decl != nullptr) { + attributes_[&expr] = {decl, + /* requires_disambiguation= */ local_decl != nullptr}; + types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } - attributes_[&expr] = decl; - types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); + ReportMissingReference(expr, name); + types_[&expr] = ErrorType(); } void ResolveVisitor::ResolveQualifiedIdentifier( @@ -1021,18 +1053,28 @@ void ResolveVisitor::ResolveQualifiedIdentifier( return; } - const VariableDecl* absl_nullable decl = nullptr; - int segment_index_out = -1; - namespace_generator_.GenerateCandidates( - qualifiers, [&decl, &segment_index_out, this](absl::string_view candidate, - int segment_index) { - decl = LookupIdentifier(candidate); - if (decl != nullptr) { - segment_index_out = segment_index; - return false; - } - return true; - }); + // Local variables (comprehension, bind) are simple identifiers so we can + // skip generating the different namespace-qualified candidates. + const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); + const VariableDecl* decl = nullptr; + + int matched_segment_index = -1; + + if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) { + decl = local_decl; + matched_segment_index = 0; + } else { + namespace_generator_.GenerateCandidates( + qualifiers, [&decl, &matched_segment_index, this]( + absl::string_view candidate, int segment_index) { + decl = LookupGlobalIdentifier(candidate); + if (decl != nullptr) { + matched_segment_index = segment_index; + return false; + } + return true; + }); + } if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); @@ -1040,7 +1082,8 @@ void ResolveVisitor::ResolveQualifiedIdentifier( return; } - const int num_select_opts = qualifiers.size() - segment_index_out - 1; + const int num_select_opts = qualifiers.size() - matched_segment_index - 1; + const Expr* root = &expr; std::vector select_opts; select_opts.reserve(num_select_opts); @@ -1049,7 +1092,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = decl; + attributes_[root] = {decl, + /* requires_disambiguation= */ decl != local_decl && + local_decl != nullptr}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. @@ -1196,13 +1241,18 @@ class ResolveRewriter : public AstRewriterBase { bool rewritten = false; if (auto iter = visitor_.attributes().find(&expr); iter != visitor_.attributes().end()) { - const VariableDecl* decl = iter->second; + const VariableDecl* decl = iter->second.decl; auto& ast_ref = reference_map_[expr.id()]; - ast_ref.set_name(decl->name()); + std::string name = decl->name(); + if (iter->second.requires_disambiguation && + !absl::StartsWith(name, ".")) { + name = absl::StrCat(".", name); + } + ast_ref.set_name(name); if (decl->has_value()) { ast_ref.set_value(decl->value()); } - expr.mutable_ident_expr().set_name(decl->name()); + expr.mutable_ident_expr().set_name(std::move(name)); rewritten = true; } else if (auto iter = visitor_.functions().find(&expr); iter != visitor_.functions().end()) { @@ -1211,7 +1261,6 @@ class ResolveRewriter : public AstRewriterBase { auto& ast_ref = reference_map_[expr.id()]; ast_ref.set_name(decl->name()); for (const auto& overload : decl->overloads()) { - // TODO(uncreated-issue/72): narrow based on type inferences and shape. ast_ref.mutable_overload_id().push_back(overload.id()); } expr.mutable_call_expr().set_function(decl->name()); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 0f07de75e..9e350d266 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -65,6 +65,7 @@ using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Not; using ::testing::Pair; using ::testing::Property; using ::testing::SizeIs; @@ -750,18 +751,18 @@ TEST(TypeCheckerImplTest, NestedComprehensions) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { +TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container("com"); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); - // Namespace resolution still applies, compre var doesn't shadow com.x + // Namespace compre var shadows com.x env.InsertVariableIfAbsent(MakeVariableDecl("com.x", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, - MakeTestParsedAst("['1', '2'].all(x, x == 2)")); + MakeTestParsedAst("['1', '2'].exists(x, x == '2')")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); @@ -769,20 +770,19 @@ TEST(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->reference_map(), - Contains(Pair(_, IsVariableReference("com.x")))); + Not(Contains(Pair(_, IsVariableReference("com.x"))))); } -TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) { +TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdent) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); - // Namespace resolution still applies, compre var doesn't shadow x.y env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, - MakeTestParsedAst("[{'y': '2'}].all(x, x.y == 2)")); + MakeTestParsedAst("[{'y': '2'}].all(x, x.y == '2')")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); @@ -790,7 +790,46 @@ TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->reference_map(), - Contains(Pair(_, IsVariableReference("x.y")))); + Not(Contains(Pair(_, IsVariableReference("x.y"))))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("[{'y': 0}].all(x, .x.y == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(_, IsVariableReference(".x.y")))); +} + +TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesIdent) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("['foo'].all(x, .x == 2)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + + EXPECT_THAT(result.GetIssues(), IsEmpty()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(_, IsVariableReference(".x")))); } TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) {