diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1e9995b19..28cbf21e0 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -48,7 +48,6 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" -#include "common/source.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -66,43 +65,6 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } -SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { - const auto& source_info = ast.source_info(); - auto iter = source_info.positions().find(expr_id); - if (iter == source_info.positions().end()) { - return SourceLocation{}; - } - int32_t absolute_position = iter->second; - if (absolute_position < 0) { - return SourceLocation{}; - } - - // Find the first line offset that is greater than the absolute position. - int32_t line_idx = -1; - int32_t offset = 0; - for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { - int32_t next_offset = source_info.line_offsets()[i]; - if (next_offset <= offset) { - // Line offset is not monotonically increasing, so line information is - // invalid. - return SourceLocation{}; - } - if (absolute_position < next_offset) { - line_idx = i; - break; - } - offset = next_offset; - } - - if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { - return SourceLocation{}; - } - - int32_t rel_position = absolute_position - offset; - - return SourceLocation{line_idx + 1, rel_position}; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. // @@ -362,7 +324,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportMissingReference(const Expr& expr, absl::string_view name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", container_, "')"))); } @@ -370,7 +332,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, absl::string_view struct_name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("undefined field '", field_name, "' not found in struct '", struct_name, "'"))); } @@ -378,7 +340,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportTypeMismatch(int64_t expr_id, const Type& expected, const Type& actual) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("expected type '", FormatTypeName(inference_context_->FinalizeType(expected)), "' but found '", @@ -408,7 +370,7 @@ class ResolveVisitor : public AstVisitorBase { } if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, field.id()), + ast_->ComputeSourceLocation(field.id()), absl::StrCat( "expected type of field '", field_info->name(), "' is '", FormatTypeName(inference_context_->FinalizeType(field_type)), @@ -553,7 +515,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); types_[&expr] = ErrorType(); @@ -605,7 +567,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. ReportIssue(TypeCheckIssue( - Severity::kWarning, ComputeSourceLocation(*ast_, key->id()), + Severity::kWarning, ast_->ComputeSourceLocation(key->id()), absl::StrCat( "unsupported map key type: ", FormatTypeName(inference_context_->FinalizeType(key_type))))); @@ -711,7 +673,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_type.kind() != TypeKind::kStruct && !IsWellKnownMessageType(resolved_name)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); types_[&expr] = ErrorType(); @@ -862,7 +824,7 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, comprehension.iter_range().id()), + ast_->ComputeSourceLocation(comprehension.iter_range().id()), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(range_type)), @@ -933,7 +895,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, if (!resolution.has_value()) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", absl::StrJoin(arg_types, ", ", @@ -1133,7 +1095,7 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, } ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, id), + ast_->ComputeSourceLocation(id), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), diff --git a/common/BUILD b/common/BUILD index 8dd8921cc..e289ef413 100644 --- a/common/BUILD +++ b/common/BUILD @@ -25,6 +25,7 @@ cc_library( hdrs = ["ast.h"], deps = [ ":expr", + ":source", "//common/ast:metadata", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", @@ -39,6 +40,7 @@ cc_test( deps = [ ":ast", ":expr", + ":source", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/common/ast.cc b/common/ast.cc index aea153197..48b6f5e0b 100644 --- a/common/ast.cc +++ b/common/ast.cc @@ -19,6 +19,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "common/ast/metadata.h" +#include "common/source.h" namespace cel { namespace { @@ -57,4 +58,41 @@ const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { return &iter->second; } +SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { + const auto& source_info = this->source_info(); + auto iter = source_info.positions().find(expr_id); + if (iter == source_info.positions().end()) { + return SourceLocation{}; + } + int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. + int32_t line_idx = -1; + int32_t offset = 0; + for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. + return SourceLocation{}; + } + if (absolute_position < next_offset) { + line_idx = i; + break; + } + offset = next_offset; + } + + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; + } + + int32_t rel_position = absolute_position - offset; + + return SourceLocation{line_idx + 1, rel_position}; +} + } // namespace cel diff --git a/common/ast.h b/common/ast.h index 1b07b9878..db336f52d 100644 --- a/common/ast.h +++ b/common/ast.h @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" +#include "common/source.h" namespace cel { @@ -135,6 +136,13 @@ class Ast final { expr_version_ = expr_version; } + // Computes the source location (line and column) for the given expression id + // from the source info (which stores absolute positions). + // + // Returns a default (empty) source location if the expression id is not found + // or the source info is not populated correctly. + SourceLocation ComputeSourceLocation(int64_t expr_id) const; + private: Expr root_expr_; SourceInfo source_info_; diff --git a/common/ast_test.cc b/common/ast_test.cc index 744b9e8d3..56e1bcd1e 100644 --- a/common/ast_test.cc +++ b/common/ast_test.cc @@ -18,6 +18,7 @@ #include "absl/container/flat_hash_map.h" #include "common/expr.h" +#include "common/source.h" #include "internal/testing.h" namespace cel { @@ -132,5 +133,56 @@ TEST(AstImpl, CheckedExprDeepCopy) { EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); } +TEST(AstImpl, ComputeSourceLocation) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20, 30}); + source_info.mutable_positions()[1] = 0; // Start of first line + source_info.mutable_positions()[2] = 5; // Middle of first line + source_info.mutable_positions()[3] = 10; // ... + source_info.mutable_positions()[4] = 15; + source_info.mutable_positions()[5] = 20; + source_info.mutable_positions()[6] = 25; + + Ast ast(Expr{}, std::move(source_info)); + + EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5})); +} + +TEST(AstImpl, ComputeSourceLocationFailures) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20}); + source_info.mutable_positions()[1] = -1; // Negative position + source_info.mutable_positions()[2] = 25; // Beyond last line offset + // ID 3 is missing + + Ast ast; + ast.mutable_source_info() = std::move(source_info); + + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{}); +} + +TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) { + { + // Empty line offsets + Ast ast; + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } + { + // Non-monotonic + SourceInfo source_info; + source_info.set_line_offsets({10, 5}); + source_info.mutable_positions()[1] = 12; + Ast ast(Expr{}, std::move(source_info)); + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } +} + } // namespace } // namespace cel