Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions checker/internal/type_check_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -134,7 +133,7 @@ absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupTypeConstant(
google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const {
CEL_ASSIGN_OR_RETURN(absl::optional<Type> type, LookupTypeName(name));
if (type.has_value()) {
return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type));
return MakeVariableDecl(type->name(), TypeType(arena, *type));
}

if (name.find('.') != name.npos) {
Expand Down Expand Up @@ -185,7 +184,7 @@ absl::StatusOr<absl::optional<StructTypeField>> TypeCheckEnv::LookupStructField(
return absl::nullopt;
}

const VariableDecl* absl_nullable VariableScope::LookupVariable(
const VariableDecl* absl_nullable VariableScope::LookupLocalVariable(
absl::string_view name) const {
const VariableScope* scope = this;
while (scope != nullptr) {
Expand All @@ -194,8 +193,7 @@ const VariableDecl* absl_nullable VariableScope::LookupVariable(
}
scope = scope->parent_;
}

return env_->LookupVariable(name);
return nullptr;
}

} // namespace cel::checker_internal
22 changes: 8 additions & 14 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ class TypeCheckEnv;
// Helper class for managing nested scopes and the local variables they
// implicitly declare.
//
// Nested scopes have a lifetime dependency on any parent scopes and the
// parent Type environment. Nested scopes should generally be managed by
// unique_ptrs.
// Nested scopes have a lifetime dependency on any parent scopes and should
// generally be managed by unique_ptrs.
class VariableScope {
public:
explicit VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND)
: env_(&env), parent_(nullptr) {}
explicit VariableScope() : parent_(nullptr) {}

VariableScope(const VariableScope&) = delete;
VariableScope& operator=(const VariableScope&) = delete;
Expand All @@ -61,18 +59,17 @@ class VariableScope {

std::unique_ptr<VariableScope> MakeNestedScope() const
ABSL_ATTRIBUTE_LIFETIME_BOUND {
return absl::WrapUnique(new VariableScope(*env_, this));
return absl::WrapUnique(new VariableScope(this));
}

const VariableDecl* absl_nullable LookupVariable(
const VariableDecl* absl_nullable LookupLocalVariable(
absl::string_view name) const;

private:
VariableScope(const TypeCheckEnv& env ABSL_ATTRIBUTE_LIFETIME_BOUND,
const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND)
: env_(&env), parent_(parent) {}
explicit VariableScope(
const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND)
: parent_(parent) {}

const TypeCheckEnv* absl_nonnull env_;
const VariableScope* absl_nullable parent_;
absl::flat_hash_map<std::string, VariableDecl> variables_;
};
Expand Down Expand Up @@ -190,9 +187,6 @@ class TypeCheckEnv {
TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
return TypeCheckEnv(this);
}
VariableScope MakeVariableScope() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
return VariableScope(*this);
}

const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const {
return descriptor_pool_.get();
Expand Down
115 changes: 82 additions & 33 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -234,6 +235,11 @@ class ResolveVisitor : public AstVisitorBase {
bool namespace_rewrite;
};

struct AttributeResolution {
const VariableDecl* decl;
bool requires_disambiguation;
};

ResolveVisitor(absl::string_view container,
NamespaceGenerator namespace_generator,
const TypeCheckEnv& env, const Ast& ast,
Expand All @@ -246,7 +252,7 @@ class ResolveVisitor : public AstVisitorBase {
inference_context_(&inference_context),
issues_(&issues),
ast_(&ast),
root_scope_(env.MakeVariableScope()),
root_scope_(),
arena_(arena),
current_scope_(&root_scope_) {}

Expand Down Expand Up @@ -294,7 +300,7 @@ class ResolveVisitor : public AstVisitorBase {
return functions_;
}

const absl::flat_hash_map<const Expr*, const VariableDecl*>& attributes()
const absl::flat_hash_map<const Expr*, AttributeResolution>& attributes()
const {
return attributes_;
}
Expand Down Expand Up @@ -344,9 +350,13 @@ class ResolveVisitor : public AstVisitorBase {
absl::string_view function_name,
int arg_count, bool is_receiver);

// Resolves the function call shape (i.e. the number of arguments and call
// style) for the given function call.
const VariableDecl* absl_nullable LookupIdentifier(absl::string_view name);
// Resolves a global identifier (i.e. declared in the CEL environment).
const VariableDecl* absl_nullable LookupGlobalIdentifier(
absl::string_view name);

// Resolves a local identifier (i.e. a bind or comrprehension var).
const VariableDecl* absl_nullable LookupLocalIdentifier(
absl::string_view name);

// Resolves the applicable function overloads for the given function call.
//
Expand Down Expand Up @@ -476,7 +486,7 @@ class ResolveVisitor : public AstVisitorBase {

// References that were resolved and may require AST rewrites.
absl::flat_hash_map<const Expr*, FunctionResolution> functions_;
absl::flat_hash_map<const Expr*, const VariableDecl*> attributes_;
absl::flat_hash_map<const Expr*, AttributeResolution> attributes_;
absl::flat_hash_map<const Expr*, std::string> struct_types_;

absl::flat_hash_map<const Expr*, Type> types_;
Expand Down Expand Up @@ -967,10 +977,20 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,
types_[&expr] = resolution->result_type;
}

const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier(
const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier(
absl::string_view name) {
if (const VariableDecl* decl = current_scope_->LookupVariable(name);
decl != nullptr) {
// Note: if we see a leading dot, this shouldn't resolve to a local variable,
// but we need to check whether we need to disambiguate against a global in
// the reference map.
if (absl::StartsWith(name, ".")) {
name = name.substr(1);
}
return current_scope_->LookupLocalVariable(name);
}

const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier(
absl::string_view name) {
if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) {
return decl;
}
absl::StatusOr<absl::optional<VariableDecl>> constant =
Expand All @@ -996,22 +1016,34 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupIdentifier(

void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
absl::string_view name) {
// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* local_decl = LookupLocalIdentifier(name);

if (local_decl != nullptr && !absl::StartsWith(name, ".")) {
attributes_[&expr] = {local_decl, false};
types_[&expr] =
inference_context_->InstantiateTypeParams(local_decl->type());
return;
}

const VariableDecl* decl = nullptr;
namespace_generator_.GenerateCandidates(
name, [&decl, this](absl::string_view candidate) {
decl = LookupIdentifier(candidate);
decl = LookupGlobalIdentifier(candidate);
// continue searching.
return decl == nullptr;
});

if (decl == nullptr) {
ReportMissingReference(expr, name);
types_[&expr] = ErrorType();
if (decl != nullptr) {
attributes_[&expr] = {decl,
/* requires_disambiguation= */ local_decl != nullptr};
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
return;
}

attributes_[&expr] = decl;
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
ReportMissingReference(expr, name);
types_[&expr] = ErrorType();
}

void ResolveVisitor::ResolveQualifiedIdentifier(
Expand All @@ -1021,26 +1053,37 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
return;
}

const VariableDecl* absl_nullable decl = nullptr;
int segment_index_out = -1;
namespace_generator_.GenerateCandidates(
qualifiers, [&decl, &segment_index_out, this](absl::string_view candidate,
int segment_index) {
decl = LookupIdentifier(candidate);
if (decl != nullptr) {
segment_index_out = segment_index;
return false;
}
return true;
});
// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]);
const VariableDecl* decl = nullptr;

int matched_segment_index = -1;

if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) {
decl = local_decl;
matched_segment_index = 0;
} else {
namespace_generator_.GenerateCandidates(
qualifiers, [&decl, &matched_segment_index, this](
absl::string_view candidate, int segment_index) {
decl = LookupGlobalIdentifier(candidate);
if (decl != nullptr) {
matched_segment_index = segment_index;
return false;
}
return true;
});
}

if (decl == nullptr) {
ReportMissingReference(expr, FormatCandidate(qualifiers));
types_[&expr] = ErrorType();
return;
}

const int num_select_opts = qualifiers.size() - segment_index_out - 1;
const int num_select_opts = qualifiers.size() - matched_segment_index - 1;

const Expr* root = &expr;
std::vector<const Expr*> select_opts;
select_opts.reserve(num_select_opts);
Expand All @@ -1049,7 +1092,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
root = &root->select_expr().operand();
}

attributes_[root] = decl;
attributes_[root] = {decl,
/* requires_disambiguation= */ decl != local_decl &&
local_decl != nullptr};
types_[root] = inference_context_->InstantiateTypeParams(decl->type());

// fix-up select operations that were deferred.
Expand Down Expand Up @@ -1196,13 +1241,18 @@ class ResolveRewriter : public AstRewriterBase {
bool rewritten = false;
if (auto iter = visitor_.attributes().find(&expr);
iter != visitor_.attributes().end()) {
const VariableDecl* decl = iter->second;
const VariableDecl* decl = iter->second.decl;
auto& ast_ref = reference_map_[expr.id()];
ast_ref.set_name(decl->name());
std::string name = decl->name();
if (iter->second.requires_disambiguation &&
!absl::StartsWith(name, ".")) {
name = absl::StrCat(".", name);
}
ast_ref.set_name(name);
if (decl->has_value()) {
ast_ref.set_value(decl->value());
}
expr.mutable_ident_expr().set_name(decl->name());
expr.mutable_ident_expr().set_name(std::move(name));
rewritten = true;
} else if (auto iter = visitor_.functions().find(&expr);
iter != visitor_.functions().end()) {
Expand All @@ -1211,7 +1261,6 @@ class ResolveRewriter : public AstRewriterBase {
auto& ast_ref = reference_map_[expr.id()];
ast_ref.set_name(decl->name());
for (const auto& overload : decl->overloads()) {
// TODO(uncreated-issue/72): narrow based on type inferences and shape.
ast_ref.mutable_overload_id().push_back(overload.id());
}
expr.mutable_call_expr().set_function(decl->name());
Expand Down
Loading