diff --git a/tensorflow_text/BUILD b/tensorflow_text/BUILD index 209285f54..d0331dbb7 100644 --- a/tensorflow_text/BUILD +++ b/tensorflow_text/BUILD @@ -5,11 +5,15 @@ load("//tensorflow_text:tftext.bzl", "extra_py_deps", "if_pywrap", "py_library", # [internal] load build_test.bzl load("//tools/build_defs/license:license.bzl", "license") +load("//tools/build_defs/testing:bzl_library.bzl", "bzl_library") # Visibility rules package( default_applicable_licenses = [":license"], - default_visibility = ["//visibility:public"], + default_visibility = [ + "//visibility:public", + "@org_tensorflow//tensorflow/core/kernels/text:__subpackages__", + ], ) license(name = "license") @@ -1728,3 +1732,18 @@ py_test( "//tensorflow_text/core/pybinds:pybinds_library", ]), ) + +bzl_library( + name = "tftext_bzl", + srcs = ["tftext.bzl"], + parse_tests = False, + visibility = ["//visibility:private"], + deps = [ + "//devtools/build_cleaner/skylark:build_defs_lib", + "//third_party/bazel_rules/rules_python/python:py_library_bzl", + "//third_party/gpus/cuda:build_defs_bzl", + "//third_party/pybind11/google3_utils:build_defs_bzl", + "//third_party/tensorflow:tensorflow_bzl", + "@rules_cc//cc:core_rules", + ], +) diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index 8aba03626..73cd23865 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -1,1041 +1,294 @@ -"""Kernels for tf.text ops.""" +"""Kernels for tf.text ops. +All implementation files moved to @org_tensorflow//tensorflow/core/kernels/text. 
+""" -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") -load("@rules_cc//cc:cc_test.bzl", "cc_test") - -# Placeholder: load proto_library -load("//tensorflow_text:tftext.bzl", "tf_cc_library", "tflite_cc_library") -# [internal] load cc_proto_library.bzl licenses(["notice"]) -# Visibility rules -package(default_visibility = ["//visibility:public"]) +package( + default_applicable_licenses = ["//tensorflow_text:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//visibility:public"], +) exports_files(["LICENSE"]) +# Aliases to relocated targets + cc_library( - name = "boise_offset_converter", - srcs = ["boise_offset_converter.cc"], - hdrs = ["boise_offset_converter.h"], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], + name = "wordpiece_tokenizer", + hdrs = ["wordpiece_tokenizer.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:wordpiece_tokenizer"], ) -cc_test( - name = "boise_offset_converter_test", - size = "small", - srcs = ["boise_offset_converter_test.cc"], - deps = [ - ":boise_offset_converter", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], +cc_library( + name = "boise_offset_converter", + hdrs = ["boise_offset_converter.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:boise_offset_converter"], ) -tf_cc_library( +cc_library( name = "boise_offset_converter_kernel", - srcs = ["boise_offset_converter_kernel.cc"], hdrs = ["boise_offset_converter_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":boise_offset_converter_kernel_template", - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:boise_offset_converter_kernel"], ) -tf_cc_library( +cc_library( name = 
"boise_offset_converter_kernel_template", hdrs = ["boise_offset_converter_kernel_template.h"], - tf_deps = [ - # tf/platform:tstring tensorflow dep, - ], - deps = [ - ":boise_offset_converter", - "@com_google_absl//absl/status", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:boise_offset_converter_kernel_template"], ) cc_library( name = "byte_splitter", - srcs = ["byte_splitter.cc"], hdrs = ["byte_splitter.h"], - deps = [ - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "byte_splitter_test", - size = "small", - srcs = ["byte_splitter_test.cc"], - deps = [ - ":byte_splitter", - "@com_google_googletest//:gtest_main", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:byte_splitter"], ) -tf_cc_library( +cc_library( name = "byte_splitter_kernel", - srcs = ["byte_splitter_kernel.cc"], hdrs = ["byte_splitter_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":byte_splitter_kernel_template", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:byte_splitter_kernel"], ) -tf_cc_library( +cc_library( name = "byte_splitter_kernel_template", hdrs = ["byte_splitter_kernel_template.h"], - tf_deps = [ - # tf/platform:tstring tensorflow dep, - ], - deps = [ - ":byte_splitter", - "@com_google_absl//absl/status", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:byte_splitter_kernel_template"], ) -tflite_cc_library( +cc_library( name = "byte_splitter_tflite", - srcs = ["byte_splitter_tflite.cc"], hdrs = ["byte_splitter_tflite.h"], - deps = [ - 
":byte_splitter_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:byte_splitter_tflite"], ) -tf_cc_library( +cc_library( name = "constrained_sequence", - srcs = ["constrained_sequence.cc"], hdrs = ["constrained_sequence.h"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:constrained_sequence"], ) -tf_cc_library( - name = "constrained_sequence_kernel", - srcs = ["constrained_sequence_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - ], - deps = [ - ":constrained_sequence", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_test( - name = "constrained_sequence_kernel_input_validation_test", - srcs = ["constrained_sequence_kernel_input_validation_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:constrained_sequence_op_cc", - ], -) - -cc_test( - name = "exp_greedy_constrained_sequence_kernel_test", - srcs = ["exp_greedy_constrained_sequence_kernel_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:constrained_sequence_op_cc", - ], -) - -cc_test( - name = "exp_viterbi_constrained_sequence_kernel_test", - srcs = ["exp_viterbi_constrained_sequence_kernel_test.cc"], - deps = [ - 
":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:constrained_sequence_op_cc", - ], -) - -tf_cc_library( +cc_library( name = "fast_bert_normalizer", hdrs = ["fast_bert_normalizer.h"], - deps = [ - ":darts_clone_trie_builder", - ":darts_clone_trie_wrapper", - ":fast_bert_normalizer_model", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@icu//:common", - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer"], ) -flatbuffer_cc_library( +cc_library( name = "fast_bert_normalizer_model", - srcs = [ - "fast_bert_normalizer_model.fbs", - ], + hdrs = ["fast_bert_normalizer_model_generated.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_model"], ) -tf_cc_library( +cc_library( name = "fast_bert_normalizer_model_builder", - srcs = ["fast_bert_normalizer_model_builder.cc"], hdrs = ["fast_bert_normalizer_model_builder.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], - deps = [ - ":darts_clone_trie_builder", - ":fast_bert_normalizer", - ":fast_bert_normalizer_model", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@icu//:common", - "@icu//:nfkc_cf", # Needed for NFKC_Casefold Unicode Normalization form. 
- "@com_googlesource_code_re2//:re2", - # lite/kernels/shim:status_macros tensorflow dep, - ], -) - -cc_test( - name = "fast_bert_normalizer_test", - size = "small", - srcs = ["fast_bert_normalizer_test.cc"], - deps = [ - ":fast_bert_normalizer", - ":fast_bert_normalizer_model_builder", - "@com_google_googletest//:gtest_main", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_model_builder"], ) cc_library( name = "fast_bert_normalizer_kernel_template", hdrs = ["fast_bert_normalizer_kernel_template.h"], - deps = [ - ":fast_bert_normalizer", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_kernel_template"], ) -tf_cc_library( +cc_library( name = "fast_bert_normalizer_tf_kernel", - srcs = ["fast_bert_normalizer_tf_kernel.cc"], hdrs = ["fast_bert_normalizer_tf_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":fast_bert_normalizer_kernel_template", - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_tf_kernel"], ) -tflite_cc_library( +cc_library( name = "fast_bert_normalizer_tflite", - srcs = ["fast_bert_normalizer_tflite.cc"], hdrs = ["fast_bert_normalizer_tflite.h"], - deps = [ - ":fast_bert_normalizer_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], -) - -cc_test( - name = "log_greedy_constrained_sequence_kernel_test", - srcs = ["log_greedy_constrained_sequence_kernel_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # 
tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:constrained_sequence_op_cc", - ], -) - -cc_test( - name = "log_viterbi_constrained_sequence_kernel_test", - srcs = ["log_viterbi_constrained_sequence_kernel_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:constrained_sequence_op_cc", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_tflite"], ) cc_library( name = "darts_clone_trie_builder", - srcs = [ - "darts_clone_trie_builder.cc", - ], - hdrs = [ - "darts_clone_trie_builder.h", - ], - deps = [ - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@darts_clone", - ], + hdrs = ["darts_clone_trie_builder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:darts_clone_trie_builder"], ) cc_library( name = "darts_clone_trie_wrapper", - hdrs = [ - "darts_clone_trie_wrapper.h", - ], - deps = [ - "@com_google_absl//absl/status:statusor", - ], -) - -cc_test( - name = "darts_clone_trie_test", - size = "small", - srcs = ["darts_clone_trie_test.cc"], - deps = [ - ":darts_clone_trie_builder", - ":darts_clone_trie_wrapper", - "@com_google_absl//absl/status", - "@com_google_googletest//:gtest_main", - ], + hdrs = ["darts_clone_trie_wrapper.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:darts_clone_trie_wrapper"], ) -tf_cc_library( +cc_library( name = "disjoint_set_forest", hdrs = ["disjoint_set_forest.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:disjoint_set_forest"], ) -cc_test( - name = "disjoint_set_forest_test", - size = "small", - srcs = ["disjoint_set_forest_test.cc"], - deps = [ - 
":disjoint_set_forest", - "@com_google_googletest//:gtest_main", - ], -) - -tf_cc_library( +cc_library( name = "fast_wordpiece_tokenizer", - srcs = ["fast_wordpiece_tokenizer.cc"], - hdrs = [ - "fast_wordpiece_tokenizer.h", - ], - deps = [ - ":darts_clone_trie_wrapper", - ":fast_wordpiece_tokenizer_model", - ":fast_wordpiece_tokenizer_utils", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@icu//:nfkc", - # lite/kernels/shim:status_macros tensorflow dep, - ], -) - -cc_test( - name = "fast_wordpiece_tokenizer_test", - srcs = ["fast_wordpiece_tokenizer_test.cc"], - data = [ - "//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model.fb", - "//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model_ver_15_1.fb", - "//tensorflow_text:python/ops/test_data/fast_wordpiece_tokenizer_model_ver_16_0.fb", - ], - deps = [ - ":fast_wordpiece_tokenizer", - ":fast_wordpiece_tokenizer_model_builder", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/flags:flag", - "@icu//:headers", - # tf:lib tensorflow dep, - ], + hdrs = ["fast_wordpiece_tokenizer.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer"], ) -flatbuffer_cc_library( +cc_library( name = "fast_wordpiece_tokenizer_model", - srcs = [ - "fast_wordpiece_tokenizer_model.fbs", - ], + hdrs = ["fast_wordpiece_tokenizer_model_generated.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_model"], ) -flatbuffer_cc_library( - name = "phrase_tokenizer_model", - srcs = [ - "phrase_tokenizer_model.fbs", - ], +cc_library( + name = "fast_wordpiece_tokenizer_model_builder", + hdrs = ["fast_wordpiece_tokenizer_model_builder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_model_builder"], ) -tf_cc_library( - name = "fast_wordpiece_tokenizer_model_builder", - srcs = 
["fast_wordpiece_tokenizer_model_builder.cc"], - hdrs = [ - "fast_wordpiece_tokenizer_model_builder.h", - ], - deps = [ - ":darts_clone_trie_builder", - ":darts_clone_trie_wrapper", - ":fast_wordpiece_tokenizer_model", - ":fast_wordpiece_tokenizer_utils", - ":sentence_fragmenter_v2", - ":string_vocab", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@cppitertools", - "@icu//:nfkc", - # lite/kernels/shim:status_macros tensorflow dep, - ], +cc_library( + name = "phrase_tokenizer_model", + hdrs = ["phrase_tokenizer_model_generated.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_model"], ) -tf_cc_library( +cc_library( name = "phrase_tokenizer_model_builder", - srcs = ["phrase_tokenizer_model_builder.cc"], - hdrs = [ - "phrase_tokenizer_model_builder.h", - ], - deps = [ - ":darts_clone_trie_builder", - ":darts_clone_trie_wrapper", - ":fast_wordpiece_tokenizer_utils", - ":phrase_tokenizer_model", - ":sentence_fragmenter_v2", - ":string_vocab", - ":whitespace_tokenizer_config_builder", - ":wordpiece_tokenizer", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@icu//:nfkc", - # lite/kernels/shim:status_macros tensorflow dep, - "//tensorflow_text/core/kernels/sentencepiece:double_array_trie_builder", - ], + hdrs = ["phrase_tokenizer_model_builder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_model_builder"], ) -tf_cc_library( +cc_library( name = "fast_wordpiece_tokenizer_kernel", - srcs = ["fast_wordpiece_tokenizer_kernel.cc"], hdrs = ["fast_wordpiece_tokenizer_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":fast_wordpiece_tokenizer_kernel_template", - # lite/kernels/shim:tf_op_shim 
tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_kernel"], ) cc_library( name = "fast_wordpiece_tokenizer_kernel_template", hdrs = ["fast_wordpiece_tokenizer_kernel_template.h"], - deps = [ - ":fast_wordpiece_tokenizer", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_kernel_template"], ) -tflite_cc_library( +cc_library( name = "fast_wordpiece_tokenizer_tflite", - srcs = ["fast_wordpiece_tokenizer_tflite.cc"], hdrs = ["fast_wordpiece_tokenizer_tflite.h"], - deps = [ - ":fast_wordpiece_tokenizer_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_tflite"], ) cc_library( name = "fast_wordpiece_tokenizer_utils", - hdrs = [ - "fast_wordpiece_tokenizer_utils.h", - ], - deps = [ - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@icu//:nfkc", - ], + hdrs = ["fast_wordpiece_tokenizer_utils.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_utils"], ) -cc_test( - name = "fast_wordpiece_tokenizer_utils_test", - srcs = ["fast_wordpiece_tokenizer_utils_test.cc"], - deps = [ - ":fast_wordpiece_tokenizer_utils", - "@com_google_googletest//:gtest_main", - ], -) - -tf_cc_library( - name = "mst_op_kernels", - srcs = ["mst_op_kernels.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - ":mst_solver", - ], -) - -tf_cc_library( +cc_library( name = "mst_solver", hdrs = ["mst_solver.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], - deps = [ - ":disjoint_set_forest", - "@com_google_absl//absl/strings", - 
"@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "mst_solver_test", - size = "small", - srcs = ["mst_solver_test.cc"], - deps = [ - ":mst_solver", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - # tf:test tensorflow dep, - ], -) - -cc_test( - name = "mst_solver_random_comparison_test", - size = "small", - timeout = "long", - srcs = ["mst_solver_random_comparison_test.cc"], - tags = [ - "nofastbuild", # exclude from non-opt TAP projects - "optonly", # exclude from non-opt TAP projects - ], - deps = [ - ":mst_solver", - ":spanning_tree_iterator", - "@com_google_googletest//:gtest", # google-only - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/flags:flag", - # tf:lib tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:mst_solver"], ) -proto_library( - name = "edit_changes_proto", - srcs = ["edit_changes.proto"], -) - -cc_proto_library( - name = "edit_changes_cc_proto", - deps = [":edit_changes_proto"], -) - -tf_cc_library( +cc_library( name = "ngrams_kernel_template", hdrs = ["ngrams_kernel_template.h"], - tf_deps = [ - # tf/platform:tstring tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - # lite/kernels/shim:tensor_view tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:ngrams_kernel_template"], ) -tf_cc_library( +cc_library( name = "ngrams_kernel", - srcs = ["ngrams_kernel.cc"], hdrs = ["ngrams_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":ngrams_kernel_template", - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:ngrams_kernel"], ) -cc_test( - name = "ngrams_kernel_test", - srcs = ["ngrams_kernel_test.cc"], - deps = [ - # tf:framework tensorflow dep, - # 
tf:test tensorflow dep, - # tf:test_main tensorflow dep, - # tf/framework:shape_inference_testutil tensorflow dep, - # tf/framework:tensor_testutil tensorflow dep, - "//tensorflow_text:ngrams_op_cc", - ], -) - -tflite_cc_library( +cc_library( name = "ngrams_tflite", - srcs = ["ngrams_tflite.cc"], hdrs = ["ngrams_tflite.h"], - deps = [ - ":ngrams_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:ngrams_tflite"], ) -cc_test( - name = "ngrams_tflite_test", - srcs = ["ngrams_tflite_test.cc"], - deps = [ - ":ngrams_tflite", - "@com_google_googletest//:gtest_main", - "@flatbuffers", - # lite:string_util tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels:test_util tensorflow dep, - # lite/schema:schema_fbs tensorflow dep, - ], -) - -tf_cc_library( - name = "normalize_kernels", - srcs = ["normalize_kernels.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":edit_changes_cc_proto", - "@com_google_absl//absl/strings", - "@icu//:nfkc", - "@icu//:nfkc_cf", - ], -) - -tflite_cc_library( +cc_library( name = "ragged_tensor_to_tensor_tflite", - srcs = ["ragged_tensor_to_tensor_tflite.cc"], hdrs = ["ragged_tensor_to_tensor_tflite.h"], - deps = [ - "@flatbuffers", - # tf/util:ragged_to_dense_util_common tensorflow dep, - # lite:framework tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels:kernel_util tensorflow dep, - # lite/kernels/internal:tensor tensorflow dep, - # lite/kernels/internal:types tensorflow dep, - ], -) - -cc_test( - name = "ragged_tensor_to_tensor_tflite_test", - srcs = ["ragged_tensor_to_tensor_tflite_test.cc"], - deps = [ - ":ragged_tensor_to_tensor_tflite", - "@com_google_googletest//:gtest_main", - "@flatbuffers", - # lite:framework tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels:test_util tensorflow dep, - # 
lite/kernels/internal:tensor tensorflow dep, - # lite/schema:schema_fbs tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:ragged_tensor_to_tensor_tflite"], ) -tf_cc_library( +cc_library( name = "regex_split", - srcs = ["regex_split.cc"], hdrs = ["regex_split.h"], - deps = [ - "@com_google_absl//absl/strings", - "@com_googlesource_code_re2//:re2", - ], -) - -tf_cc_library( - name = "regex_split_kernels", - srcs = ["regex_split_kernels.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - ":regex_split", - "@com_google_absl//absl/memory", - ], -) - -cc_test( - name = "regex_split_test", - srcs = ["regex_split_test.cc"], - deps = [ - ":regex_split", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_googlesource_code_re2//:re2", - # tf:lib tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:regex_split"], ) cc_library( name = "round_robin_trimmer", hdrs = ["round_robin_trimmer.h"], - deps = [ - ":trimmer", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "round_robin_trimmer_test", - size = "small", - srcs = ["round_robin_trimmer_test.cc"], - deps = [ - ":round_robin_trimmer", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - # tf:lib tensorflow dep, - # tf:test_main tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer"], ) -tf_cc_library( +cc_library( name = "round_robin_trimmer_kernel", - srcs = ["round_robin_trimmer_kernel.cc"], hdrs = ["round_robin_trimmer_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":round_robin_trimmer_kernel_template", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer_kernel"], ) 
-tf_cc_library( +cc_library( name = "round_robin_trimmer_kernel_template", hdrs = ["round_robin_trimmer_kernel_template.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":round_robin_trimmer", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:tensor_view tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer_kernel_template"], ) -tflite_cc_library( +cc_library( name = "round_robin_trimmer_tflite", - srcs = ["round_robin_trimmer_tflite.cc"], hdrs = ["round_robin_trimmer_tflite.h"], - deps = [ - ":round_robin_trimmer_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - # lite/kernels/shim:tflite_op_wrapper tensorflow dep, - ], -) - -tf_cc_library( - name = "rouge_l_kernel", - srcs = ["rouge_l_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "rouge_l_kernel_test", - size = "small", - srcs = ["rouge_l_kernel_test.cc"], - deps = [ - ":rouge_l_kernel", - # tf:framework tensorflow dep, - # tf:test tensorflow dep, - # tf:test_main tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:text_similarity_metric_ops_cc", - ], -) - -tf_cc_library( - name = "sentence_breaking_kernels", - srcs = ["sentence_breaking_kernels.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":sentence_breaking_utils", - ":sentence_fragmenter", - "@com_google_absl//absl/strings", - "@icu//:common", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer_tflite"], ) -tf_cc_library( +cc_library( 
name = "sentence_breaking_utils", - srcs = ["sentence_breaking_utils.cc"], hdrs = ["sentence_breaking_utils.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/strings", - "@icu//:common", - ], -) - -cc_test( - name = "sentence_breaking_utils_test", - size = "small", - srcs = ["sentence_breaking_utils_test.cc"], - deps = [ - ":sentence_breaking_utils", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - "@icu//:common", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_breaking_utils"], ) -tf_cc_library( +cc_library( name = "sentence_fragmenter", - srcs = ["sentence_fragmenter.cc"], hdrs = ["sentence_fragmenter.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], - deps = [ - ":sentence_breaking_utils", - "@com_google_absl//absl/status", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter"], ) -tf_cc_library( +cc_library( name = "sentence_fragmenter_v2", - srcs = ["sentence_fragmenter_v2.cc"], hdrs = ["sentence_fragmenter_v2.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@icu//:common", - ], -) - -cc_test( - name = "sentence_fragmenter_v2_test", - srcs = ["sentence_fragmenter_v2_test.cc"], - deps = [ - ":sentence_fragmenter_v2", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@icu//:common", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2"], ) -tf_cc_library( +cc_library( name = "sentence_fragmenter_v2_kernel", - srcs = ["sentence_fragmenter_v2_kernel.cc"], hdrs = ["sentence_fragmenter_v2_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":sentence_fragmenter_v2_kernel_template", - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2_kernel"], ) -tf_cc_library( +cc_library( name = 
"sentence_fragmenter_v2_kernel_template", hdrs = ["sentence_fragmenter_v2_kernel_template.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":sentence_fragmenter_v2", - "@com_google_absl//absl/status", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - # lite/kernels/shim:tensor_view tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2_kernel_template"], ) -tflite_cc_library( +cc_library( name = "sentence_fragmenter_v2_tflite", - srcs = ["sentence_fragmenter_v2_tflite.cc"], hdrs = ["sentence_fragmenter_v2_tflite.h"], - deps = [ - ":sentence_fragmenter_v2_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2_tflite"], ) -tf_cc_library( - name = "sentencepiece_kernels", - srcs = ["sentencepiece_kernels.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:framework_headers_lib tensorflow dep, - # tf:lib tensorflow dep, - # tf:protos_all_cc tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_sentencepiece//:sentencepiece_cc_proto", - "@com_google_sentencepiece//:sentencepiece_model_cc_proto", - "@com_google_sentencepiece//:sentencepiece_processor", - ], -) - -tf_cc_library( +cc_library( name = "spanning_tree_iterator", testonly = 1, - srcs = ["spanning_tree_iterator.cc"], hdrs = ["spanning_tree_iterator.h"], - tf_deps = [ - # tf:lib tensorflow dep, - ], -) - -cc_test( - name = "spanning_tree_iterator_test", - size = "small", - srcs = ["spanning_tree_iterator_test.cc"], - deps = [ - ":spanning_tree_iterator", - 
"@com_google_googletest//:gtest_main", - # tf:lib tensorflow dep, - ], -) - -tf_cc_library( - name = "split_merge_tokenize_kernel", - srcs = ["split_merge_tokenize_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@icu//:common", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:spanning_tree_iterator"], ) cc_library( name = "text_kernels_test_util", testonly = 1, - srcs = ["text_kernels_test_util.cc"], hdrs = ["text_kernels_test_util.h"], - deps = [ - "@com_google_googletest//:gtest", - # tf:framework tensorflow dep, - # tf:testlib tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:text_kernels_test_util"], ) -tflite_cc_library( +cc_library( name = "tflite_ops", hdrs = [ "byte_splitter_tflite.h", @@ -1047,362 +300,310 @@ tflite_cc_library( "sentence_fragmenter_v2_tflite.h", "utf8_binarize_tflite.h", "whitespace_tokenizer_tflite.h", - "//tensorflow_text/core/kernels/sentencepiece:sp_headers", - ], - deps = [ - ":byte_splitter_tflite", - ":fast_bert_normalizer_tflite", - ":fast_wordpiece_tokenizer_tflite", - ":ngrams_tflite", - ":ragged_tensor_to_tensor_tflite", - ":round_robin_trimmer_tflite", - ":sentence_fragmenter_v2_tflite", - ":utf8_binarize_tflite", - ":whitespace_tokenizer_tflite", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - "//tensorflow_text/core/kernels/sentencepiece:py_tflite_registerer", - ], -) - -tf_cc_library( - name = "tokenizer_from_logits_kernel", - srcs = ["tokenizer_from_logits_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@icu//:common", ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:tflite_ops"], ) cc_library( name = "trimmer", hdrs = 
["trimmer.h"], - deps = [ - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_library( - name = "unicode_script_tokenize_kernel", - srcs = ["unicode_script_tokenize_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - "@icu//:common", - ], -) - -cc_test( - name = "unicode_script_tokenize_kernel_test", - srcs = ["unicode_script_tokenize_kernel_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:unicode_script_tokenizer_cc", - ], -) - -tf_cc_library( - name = "whitespace_tokenize_kernel", - srcs = ["whitespace_tokenize_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - "@icu//:common", - ], -) - -cc_test( - name = "whitespace_tokenize_kernel_test", - srcs = ["whitespace_tokenize_kernel_test.cc"], - deps = [ - ":text_kernels_test_util", - "@com_google_googletest//:gtest_main", - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - # tf:test tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:whitespace_tokenizer_cc", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:trimmer"], ) cc_library( name = "whitespace_tokenizer", - srcs = ["whitespace_tokenizer.cc"], hdrs = ["whitespace_tokenizer.h"], - deps = [ - "@com_google_absl//absl/strings", - "@icu//:common", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer"], ) -cc_test( - name = "whitespace_tokenizer_test", - size = "small", - srcs = ["whitespace_tokenizer_test.cc"], - deps = [ - ":whitespace_tokenizer", - ":whitespace_tokenizer_config_builder", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/status", - 
"@com_google_absl//absl/status:statusor", - # tf:lib tensorflow dep, - # tf:test_main tensorflow dep, - ], -) - -tf_cc_library( +cc_library( name = "whitespace_tokenizer_kernel", - srcs = ["whitespace_tokenizer_kernel.cc"], hdrs = ["whitespace_tokenizer_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":whitespace_tokenizer_kernel_template", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_kernel"], ) -tf_cc_library( +cc_library( name = "whitespace_tokenizer_kernel_template", hdrs = ["whitespace_tokenizer_kernel_template.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":whitespace_tokenizer", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:tensor_view tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_kernel_template"], ) -tflite_cc_library( +cc_library( name = "whitespace_tokenizer_tflite", - srcs = ["whitespace_tokenizer_tflite.cc"], hdrs = ["whitespace_tokenizer_tflite.h"], - deps = [ - ":whitespace_tokenizer_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_tflite"], ) cc_library( name = "whitespace_tokenizer_config_builder", - srcs = ["whitespace_tokenizer_config_builder.cc"], hdrs = ["whitespace_tokenizer_config_builder.h"], - deps = [ - "@icu//:common", - ], -) - -cc_test( - name = "whitespace_tokenizer_config_builder_test", - size = "small", - srcs = ["whitespace_tokenizer_config_builder_test.cc"], - deps = [ - ":whitespace_tokenizer", - ":whitespace_tokenizer_config_builder", - 
"@com_google_googletest//:gtest_main", - "@icu//:common", - # tf:lib tensorflow dep, - # tf:test_main tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_config_builder"], ) cc_library( name = "string_vocab", - srcs = ["string_vocab.cc"], hdrs = ["string_vocab.h"], - deps = [ - ":wordpiece_tokenizer", - "@com_google_absl//absl/container:flat_hash_map", - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:string_vocab"], ) cc_library( name = "phrase_tokenizer", - srcs = ["phrase_tokenizer.cc"], hdrs = ["phrase_tokenizer.h"], - deps = [ - ":phrase_tokenizer_model", - ":string_vocab", - ":whitespace_tokenizer", - ":whitespace_tokenizer_config_builder", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/random", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - # lite/kernels/shim:status_macros tensorflow dep, - "//tensorflow_text/core/kernels/sentencepiece:double_array_trie", - ], -) - -cc_test( - name = "phrase_tokenizer_test", - size = "small", - srcs = ["phrase_tokenizer_test.cc"], - data = [ - "//tensorflow_text:python/ops/test_data/phrase_tokenizer_model.fb", - ], - deps = [ - ":phrase_tokenizer", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - # tf:lib tensorflow dep, - # tf:test_main tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer"], ) cc_library( name = "phrase_tokenizer_kernel_template", hdrs = ["phrase_tokenizer_kernel_template.h"], - deps = [ - ":phrase_tokenizer", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_kernel_template"], ) 
-tf_cc_library( +cc_library( name = "phrase_tokenizer_kernel", - srcs = ["phrase_tokenizer_kernel.cc"], hdrs = ["phrase_tokenizer_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":phrase_tokenizer_kernel_template", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_kernel"], ) cc_library( name = "utf8_binarize", - srcs = ["utf8_binarize.cc"], hdrs = ["utf8_binarize.h"], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@icu//:common", - ], -) - -cc_test( - name = "utf8_binarize_test", - size = "small", - srcs = ["utf8_binarize_test.cc"], - deps = [ - ":utf8_binarize", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/types:span", - # tf:lib tensorflow dep, - # tf:test_main tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize"], ) cc_library( name = "utf8_binarize_kernel_template", hdrs = ["utf8_binarize_kernel_template.h"], - deps = [ - ":utf8_binarize", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - # tf/platform:tstring tensorflow dep, - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:shape tensorflow dep, - # lite/kernels/shim:status_macros tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize_kernel_template"], ) -tf_cc_library( +cc_library( name = "utf8_binarize_kernel", - srcs = ["utf8_binarize_kernel.cc"], hdrs = ["utf8_binarize_kernel.h"], - tf_deps = [ - # tf:framework tensorflow dep, - ], - deps = [ - ":utf8_binarize_kernel_template", - # lite/kernels/shim:op_kernel tensorflow dep, - # lite/kernels/shim:tf_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize_kernel"], ) -tflite_cc_library( +cc_library( name = "utf8_binarize_tflite", - srcs = ["utf8_binarize_tflite.cc"], hdrs = 
["utf8_binarize_tflite.h"], - deps = [ - ":utf8_binarize_kernel_template", - # lite:mutable_op_resolver tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels/shim:tflite_op_shim tensorflow dep, - ], + deps = ["@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize_tflite"], ) -tf_cc_library( - name = "wordpiece_kernel", - srcs = ["wordpiece_kernel.cc"], - tf_deps = [ - # tf:framework tensorflow dep, - # tf:lib tensorflow dep, - ], - deps = [ - ":wordpiece_tokenizer", - "@com_google_absl//absl/base:core_headers", - ], +alias( + name = "boise_offset_converter_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:boise_offset_converter_test", ) -tf_cc_library( - name = "wordpiece_tokenizer", - srcs = ["wordpiece_tokenizer.cc"], - hdrs = ["wordpiece_tokenizer.h"], - deps = [ - "@com_google_absl//absl/strings", - "@icu//:common", - ], +alias( + name = "byte_splitter_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:byte_splitter_test", +) + +alias( + name = "constrained_sequence_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:constrained_sequence_kernel", ) -cc_test( +alias( + name = "constrained_sequence_kernel_input_validation_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:constrained_sequence_kernel_input_validation_test", +) + +alias( + name = "exp_greedy_constrained_sequence_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:exp_greedy_constrained_sequence_kernel_test", +) + +alias( + name = "exp_viterbi_constrained_sequence_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:exp_viterbi_constrained_sequence_kernel_test", +) + +alias( + name = "fast_bert_normalizer_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_test", +) + +alias( + name = "log_greedy_constrained_sequence_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:log_greedy_constrained_sequence_kernel_test", +) + +alias( + name 
= "log_viterbi_constrained_sequence_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:log_viterbi_constrained_sequence_kernel_test", +) + +alias( + name = "darts_clone_trie_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:darts_clone_trie_test", +) + +alias( + name = "disjoint_set_forest_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:disjoint_set_forest_test", +) + +alias( + name = "fast_wordpiece_tokenizer_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_test", +) + +alias( + name = "fast_wordpiece_tokenizer_utils_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_utils_test", +) + +alias( + name = "mst_op_kernels", + actual = "@org_tensorflow//tensorflow/core/kernels/text:mst_op_kernels", +) + +alias( + name = "mst_solver_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:mst_solver_test", +) + +alias( + name = "mst_solver_random_comparison_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:mst_solver_random_comparison_test", +) + +alias( + name = "edit_changes_proto", + actual = "@org_tensorflow//tensorflow/core/kernels/text:edit_changes_proto", +) + +alias( + name = "edit_changes_cc_proto", + actual = "@org_tensorflow//tensorflow/core/kernels/text:edit_changes_cc_proto", +) + +alias( + name = "ngrams_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:ngrams_kernel_test", +) + +alias( + name = "ngrams_tflite_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:ngrams_tflite_test", +) + +alias( + name = "normalize_kernels", + actual = "@org_tensorflow//tensorflow/core/kernels/text:normalize_kernels", +) + +alias( + name = "ragged_tensor_to_tensor_tflite_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:ragged_tensor_to_tensor_tflite_test", +) + +alias( + name = "regex_split_kernels", + actual = 
"@org_tensorflow//tensorflow/core/kernels/text:regex_split_kernels", +) + +alias( + name = "regex_split_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:regex_split_test", +) + +alias( + name = "round_robin_trimmer_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer_test", +) + +alias( + name = "rouge_l_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:rouge_l_kernel", +) + +alias( + name = "rouge_l_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:rouge_l_kernel_test", +) + +alias( + name = "sentence_breaking_kernels", + actual = "@org_tensorflow//tensorflow/core/kernels/text:sentence_breaking_kernels", +) + +alias( + name = "sentence_breaking_utils_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:sentence_breaking_utils_test", +) + +alias( + name = "sentence_fragmenter_v2_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2_test", +) + +alias( + name = "sentencepiece_kernels", + actual = "@org_tensorflow//tensorflow/core/kernels/text:sentencepiece_kernels", +) + +alias( + name = "spanning_tree_iterator_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:spanning_tree_iterator_test", +) + +alias( + name = "split_merge_tokenize_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:split_merge_tokenize_kernel", +) + +alias( + name = "tokenizer_from_logits_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:tokenizer_from_logits_kernel", +) + +alias( + name = "unicode_script_tokenize_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:unicode_script_tokenize_kernel", +) + +alias( + name = "unicode_script_tokenize_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:unicode_script_tokenize_kernel_test", +) + +alias( + name = "whitespace_tokenize_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenize_kernel", +) + +alias( + 
name = "whitespace_tokenize_kernel_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenize_kernel_test", +) + +alias( + name = "whitespace_tokenizer_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_test", +) + +alias( + name = "whitespace_tokenizer_config_builder_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_config_builder_test", +) + +alias( + name = "phrase_tokenizer_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_test", +) + +alias( + name = "utf8_binarize_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize_test", +) + +alias( + name = "wordpiece_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text:wordpiece_kernel", +) + +alias( name = "wordpiece_kernel_test", - size = "small", - srcs = ["wordpiece_kernel_test.cc"], - deps = [ - ":wordpiece_kernel", - # tf:framework tensorflow dep, - # tf:test tensorflow dep, - # tf:test_main tensorflow dep, - # tf:testlib tensorflow dep, - # tf/kernels:ops_testutil tensorflow dep, - "//tensorflow_text:wordpiece_tokenizer_cc", - ], + actual = "@org_tensorflow//tensorflow/core/kernels/text:wordpiece_kernel_test", ) diff --git a/tensorflow_text/core/kernels/boise_offset_converter.cc b/tensorflow_text/core/kernels/boise_offset_converter.cc deleted file mode 100644 index ac306d3e3..000000000 --- a/tensorflow_text/core/kernels/boise_offset_converter.cc +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/boise_offset_converter.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace text { - -bool IsRightOutsideSpan(int token_start, int token_end, int span_start, - int span_end) { - // Token: |------) - // Span: |-----) - return token_start >= span_end; -} - -bool IsLeftOutsideSpan(int token_start, int token_end, int span_start, - int span_end) { - // Token: |------) - // Span: |-----) - return token_end <= span_start; -} - -bool IsStartOfSpan(int token_start, int token_end, int span_start, - int span_end) { - // Returns true if the token overlaps with the span from the - // left side (i.e. start) of the span, but not have the span inside. - // Token: |-------) - // Span: |-----) - return token_start <= span_start && token_end > span_start && - token_end <= span_end; -} - -bool IsEndOfSpan(int token_start, int token_end, int span_start, int span_end) { - // Returns true if the token overlaps with the span from the - // right side (i.e. end) of the span, but not have the span inside. 
- // Token: |------) - // Span: |------) - return token_start < span_end && token_end >= span_end && - token_start >= span_start; -} - -bool IsInsideSpan(int token_start, int token_end, int span_start, - int span_end) { - // Token: |------) - // Span: |-----------) - return token_start >= span_start && token_end <= span_end; -} - -absl::StatusOr> OffsetsToBoiseTags( - const std::vector& token_begin_offsets, - const std::vector& token_end_offsets, - const std::vector& span_begin_offsets, - const std::vector& span_end_offsets, - const std::vector& span_type, - const bool use_strict_boundary_mode) { - // Verify that token vectors are all the same size - if (token_begin_offsets.size() != token_end_offsets.size()) { - return absl::InvalidArgumentError("Token offsets must have the same size"); - } - if (span_begin_offsets.size() != span_end_offsets.size() || - span_begin_offsets.size() != span_type.size()) { - return absl::InvalidArgumentError("Span offsets must have the same size"); - } - - // Iterate through tokens - std::vector results; - int span_index = 0; - for (int i = 0; i < token_begin_offsets.size(); ++i) { - int token_start = token_begin_offsets[i]; - int token_end = token_end_offsets[i]; - std::string potential_span_type = "O"; - bool recorded = false; - - while (span_index < span_begin_offsets.size() && !recorded) { - int span_start = span_begin_offsets[span_index]; - int span_end = span_end_offsets[span_index]; - - if (IsLeftOutsideSpan(token_start, token_end, span_start, span_end)) { - results.push_back(potential_span_type); - recorded = true; - } else if (IsRightOutsideSpan(token_start, token_end, span_start, - span_end)) { - span_index++; - } else if (IsStartOfSpan(token_start, token_end, span_start, span_end)) { - if (IsEndOfSpan(token_start, token_end, span_start, span_end)) { - results.push_back(absl::StrCat("S-", span_type[span_index])); - span_index++; - recorded = true; - } else { - if (use_strict_boundary_mode && token_start != span_start) { - 
results.push_back(potential_span_type); - recorded = true; - } else { - results.push_back(absl::StrCat("B-", span_type[span_index])); - recorded = true; - } - } - } else if (IsEndOfSpan(token_start, token_end, span_start, span_end)) { - if (use_strict_boundary_mode && token_end != span_end) { - results.push_back(potential_span_type); - recorded = true; - } else { - potential_span_type = absl::StrCat("E-", span_type[span_index]); - } - span_index++; - } else if (IsInsideSpan(token_start, token_end, span_start, span_end)) { - // token: |--) - // span: |---------) - results.push_back(absl::StrCat("I-", span_type[span_index])); - recorded = true; - } else { - // token: |----------) - // span: |----) - potential_span_type = absl::StrCat("B-", span_type[span_index]); - span_index++; - } - } - if (!recorded) { - results.push_back(potential_span_type); - } - } - return results; -} - -std::string ExtractSpanType(const std::string& tag) { - return std::string(absl::ClippedSubstr(tag, 2).data()); -} - -absl::StatusOr< - std::tuple, std::vector, std::vector>> -BoiseTagsToOffsets(const std::vector& token_begin_offsets, - const std::vector& token_end_offsets, - const std::vector& per_token_boise_tags) { - // Verify that input vectors are all the same size - if (token_begin_offsets.size() != token_end_offsets.size()) { - return absl::InvalidArgumentError("Tokens must have the same size"); - } - if (token_begin_offsets.size() != per_token_boise_tags.size()) { - return absl::InvalidArgumentError( - "Tokens and BOISE tags must have the same size"); - } - - std::vector span_start, span_end; - std::vector span_type; - // Iterate through each token - int potential_span_start = -1; - std::string potential_span_type; - bool started_span = false; - - for (int i = 0; i < token_begin_offsets.size(); ++i) { - // If we find a (B)egin, (I)nside, (E)nd, or (S)ingleton tag then - // record a span start. 
- const std::string& tag = per_token_boise_tags[i]; - - if (!started_span) { - if (absl::StartsWith(tag, "B-") || absl::StartsWith(tag, "I-")) { - potential_span_start = token_begin_offsets[i]; - started_span = true; - potential_span_type = ExtractSpanType(tag); - } - - if (absl::StartsWith(tag, "E-") || absl::StartsWith(tag, "S-")) { - // Treat this as a singleton - span_start.push_back(token_begin_offsets[i]); - span_end.push_back(token_end_offsets[i]); - span_type.push_back(ExtractSpanType(tag)); - started_span = false; - potential_span_type.clear(); - } - } else { - // If we have found a Outside, but we previously had a span start (from - // a Begin, or Inside) then treat this as a singleton and record an span - // end - if (absl::StartsWith(tag, "O")) { - span_start.push_back(potential_span_start); - span_end.push_back(token_end_offsets[i - 1]); - span_type.push_back(potential_span_type); - started_span = false; - potential_span_type.clear(); - } - - // If we find a End or Singleton then also record an end. - if (absl::StartsWith(tag, "E-") || absl::StartsWith(tag, "S-")) { - span_start.push_back(potential_span_start); - span_end.push_back(token_end_offsets[i]); - // Also record a span type. - span_type.push_back(ExtractSpanType(tag)); - started_span = false; - } - - // If we find a Begin, - if (absl::StartsWith(tag, "B-") || absl::StartsWith(tag, "I-")) { - // potential_span_start = token_begin_offsets[i]; - started_span = true; - potential_span_type = ExtractSpanType(tag); - } - } - } - - // Record span that has started but not closed. 
- if (started_span) { - span_start.push_back(potential_span_start); - span_end.push_back(token_end_offsets.back()); - span_type.push_back(potential_span_type); - } - - return std::tuple, std::vector, - std::vector>(span_start, span_end, span_type); -} - -std::unordered_set GetAllBoiseTagsFromSpanType( - const std::vector& span_type) { - std::unordered_set res{"O"}; - const std::unordered_set deduped_span_type(span_type.begin(), - span_type.end()); - const std::vector boise_prefixes = {"B-", "I-", "S-", "E-"}; - - for (const std::string& cur_span_type : deduped_span_type) { - if (cur_span_type.empty() || cur_span_type == "O") { - continue; - } - for (const std::string& prefix : boise_prefixes) { - std::string tag = absl::StrCat(prefix, cur_span_type); - res.insert(tag); - } - } - - return res; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/boise_offset_converter.h b/tensorflow_text/core/kernels/boise_offset_converter.h index cfe21d128..73ef142c2 100644 --- a/tensorflow_text/core/kernels/boise_offset_converter.h +++ b/tensorflow_text/core/kernels/boise_offset_converter.h @@ -15,112 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_H_ -#include -#include - -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { -// Translates span begin/end offsets and token begin/end offsets into a -// BOISE scheme. -// -// In the BOISE scheme there is a set of 5 labels for each type: -// - (B)egin: meaning the beginning of the span type. -// - (O)utside: meaning the token is outside of any span type -// - (I)nside: the token is inside the span -// - (S)ingleton: the entire span consists of this single token. -// - (E)nd: this token is the end of the span. 
-// -// When given the span begin & end offsets along with a set of token begin & end -// offsets, this function helps translate which each token into one of the 5 -// labels. -// -// For example, given the following example inputs: -// -// std::string content = "Who let the dogs out"; -// std::string entity = "dogs"; -// std::vector tokens = { "Who", "let", "the", "dogs", "out" } -// std::vector token_begin_offsets = { 0, 4, 8, 12, 17 }; -// std::vector token_end_offsets = { 3, 7, 11, 16, 20 }; -// std::vector span_begin_offsets = { 12 }; -// std::vector span_end_offsets = { 16 }; -// std::vector span_type = { "animal" } -// -// Foo will produce the following labels: -// { "O", "O", "O", "S-animal", "O", } -// | | | | | -// Who let the dogs out -// -// Special Case 1: Loose or Strict Boundary Criteria: -// By default, loose boundary criteria are used to decide token start and end, -// given a entity span. In the above example, say if we have -// -// std::vector span_begin_offsets = { 13 }; -// std::vector span_end_offsets = { 16 }; -// -// we still get { "O", "O", "O", "S-animal", "O", }, even though the span -// begin offset (13) is not exactly aligned with the token begin offset (12). -// Partial overlap between a token and a BOISE tag still qualify the token to -// be labeled with this tag. -// -// You can choose to use strict boundary criteria by passing in -// use_strict_boundary_mode = false argument, with which Foo will produce -// { "O", "O", "O", "O", "O", } for the case described above. -// -// Special Case 2: One Token Mapped to Multiple BOISE Tags: -// In cases where a token is overlapped with multiple BOISE tags, we label the -// token with the last tag. 
For example, given the following example inputs: -// -// std::string content = "Getty Center"; -// std::vector tokens = { "Getty Center" }; -// std::vector token_begin_offsets = { 0 }; -// std::vector token_end_offsets = { 12 }; -// std::vector span_begin_offsets = { 0, 6 }; -// std::vector span_end_offsets = { 5, 12 }; -// std::vector span_type = { "per", "loc" } -// -// Foo will produce the following labels: -// { "B-loc", } -absl::StatusOr> OffsetsToBoiseTags( - const std::vector& token_begin_offsets, - const std::vector& token_end_offsets, - const std::vector& span_begin_offsets, - const std::vector& span_end_offsets, - const std::vector& span_type, - const bool use_strict_boundary_mode = false); - -// Given the token offsets and BOISE tags per token, perform a translation -// that marks start offset, end offset and span type per entity. -// -// For example, given the following example inputs: -// -// std::vector token_begin_offsets = { 0, 4, 8, 12, 17 }; -// std::vector token_end_offsets = { 3, 7, 11, 16, 20 }; -// std::vector per_token_boise_tags = { "O", "O", "O", "S-animal", -// "O" }; -// -// Foo will produce the following offsets and labels vectors: -// start offsets: { 12, } -// end offsets: { 16, } -// span types: { "animal", } -absl::StatusOr< - std::tuple, std::vector, std::vector>> -BoiseTagsToOffsets(const std::vector& token_begin_offsets, - const std::vector& token_end_offsets, - const std::vector& per_token_boise_tags); - -// Get all possible BOISE tags for given span types. For example, -// -// std::vector span_type = { "loc", "per" } -// -// Foo will produce an unordered set: -// { "O", "B-loc", "I-loc", "S-loc", "E-loc", "B-per", "I-per", "S-per", -// "E-per", }. 
-std::unordered_set GetAllBoiseTagsFromSpanType( - const std::vector& span_type); - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/boise_offset_converter.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_H_ diff --git a/tensorflow_text/core/kernels/boise_offset_converter_kernel.cc b/tensorflow_text/core/kernels/boise_offset_converter_kernel.cc deleted file mode 100644 index d2ed8c42a..000000000 --- a/tensorflow_text/core/kernels/boise_offset_converter_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/boise_offset_converter_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER( - Name(OffsetsToBoiseTagsOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - OffsetsToBoiseTagsOpKernel); - -REGISTER_KERNEL_BUILDER( - Name(BoiseTagsToOffsetsOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - BoiseTagsToOffsetsOpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/boise_offset_converter_kernel.h b/tensorflow_text/core/kernels/boise_offset_converter_kernel.h index e8a978d36..873d03d05 100644 --- a/tensorflow_text/core/kernels/boise_offset_converter_kernel.h +++ b/tensorflow_text/core/kernels/boise_offset_converter_kernel.h @@ -15,25 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h" - -namespace tensorflow { -namespace text { - -class OffsetsToBoiseTagsOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -class BoiseTagsToOffsetsOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/boise_offset_converter_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h b/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h index f49f059aa..9867b6c2c 100644 --- a/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h +++ b/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h @@ -15,625 +15,6 @@ #ifndef 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_TEMPLATE_H_ -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/boise_offset_converter.h" - -namespace tensorflow { -namespace text { - -template -class OffsetsToBoiseTagsOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kInputTokenBeginOffsets = 0, - kInputTokenEndOffsets, - kInputSpanBeginOffsets, - kInputSpanEndOffsets, - kInputSpanType, - kInputTokenBeginRowSplits, - kInputTokenEndRowSplits, - kInputSpanBeginRowSplits, - kInputSpanEndRowSplits, - kInputSpanTypeRowSplits, - kInputUseStrictBoundaryMode - }; - enum Outputs { kOutputBoiseTags = 0 }; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - OffsetsToBoiseTagsOp() = default; - static constexpr char kOpName[] = "TFText>OffsetsToBoiseTags"; - static constexpr char kDoc[] = R"doc( - Converts token/span begin/end offsets into BOISE tags. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -std::vector OffsetsToBoiseTagsOp::Inputs() { - return {"input_token_begin_offsets: int32", - "input_token_end_offsets: int32", - "input_span_begin_offsets: int32", - "input_span_end_offsets: int32", - "input_span_type: string", - "input_token_begin_row_splits: int64", - "input_token_end_row_splits: int64", - "input_span_begin_row_splits: int64", - "input_span_end_row_splits: int64", - "input_span_type_row_splits: int64", - "input_use_strict_boundary_mode: bool"}; -} - -template -std::vector OffsetsToBoiseTagsOp::Outputs() { - return {"output_boise_tags: string"}; -} - -template -absl::Status OffsetsToBoiseTagsOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - - SH_ASSIGN_OR_RETURN(const Shape input_token_begin_shape, - c->GetInputShape(kInputTokenBeginOffsets)); - if (!input_token_begin_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_begin_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_end_shape, - c->GetInputShape(kInputTokenEndOffsets)); - if (!input_token_end_shape.Compatible(rank_1_shape)) { - return 
absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_end_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_begin_shape, - c->GetInputShape(kInputSpanBeginOffsets)); - if (!input_span_begin_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_begin_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_end_shape, - c->GetInputShape(kInputSpanEndOffsets)); - if (!input_span_end_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_end_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_type_shape, - c->GetInputShape(kInputSpanType)); - if (!input_span_type_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_type_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_begin_rs_shape, - c->GetInputShape(kInputTokenBeginRowSplits)); - if (!input_token_begin_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_begin_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_end_rs_shape, - c->GetInputShape(kInputTokenEndRowSplits)); - if (!input_token_end_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_end_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_begin_rs_shape, - c->GetInputShape(kInputSpanBeginRowSplits)); - if (!input_span_begin_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_begin_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_end_rs_shape, - c->GetInputShape(kInputSpanEndRowSplits)); - if (!input_span_end_rs_shape.Compatible(rank_1_shape)) 
{ - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_end_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_span_type_rs_shape, - c->GetInputShape(kInputSpanTypeRowSplits)); - if (!input_span_type_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_span_type_rs_shape.ToString())); - } - - const int num_offsets = input_token_begin_shape.Dim(0); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputBoiseTags, Shape({num_offsets}))); - - return absl::OkStatus(); -} - -template -absl::Status OffsetsToBoiseTagsOp::Invoke(InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto input_token_begin_offsets, - context->GetInput(kInputTokenBeginOffsets)); - const auto& input_token_begin_offsets_vec = - input_token_begin_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_end_offsets, - context->GetInput(kInputTokenEndOffsets)); - const auto& input_token_end_offsets_vec = - input_token_end_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_begin_offsets, - context->GetInput(kInputSpanBeginOffsets)); - const auto& input_span_begin_offsets_vec = - input_span_begin_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_end_offsets, - context->GetInput(kInputSpanEndOffsets)); - const auto& input_span_end_offsets_vec = - input_span_end_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_type, - context->GetInput(kInputSpanType)); - const auto& input_span_type_vec = - input_span_type->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_begin_row_splits, - context->GetInput(kInputTokenBeginRowSplits)); - const auto& input_token_begin_row_splits_vec = - input_token_begin_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_end_row_splits, - context->GetInput(kInputTokenEndRowSplits)); - const auto& input_token_end_row_splits_vec = - 
input_token_end_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_begin_row_splits, - context->GetInput(kInputSpanBeginRowSplits)); - const auto& input_span_begin_row_splits_vec = - input_span_begin_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_end_row_splits, - context->GetInput(kInputSpanEndRowSplits)); - const auto& input_span_end_row_splits_vec = - input_span_end_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_span_type_row_splits, - context->GetInput(kInputSpanTypeRowSplits)); - const auto& input_span_type_row_splits_vec = - input_span_type_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_use_strict_boundary_mode, - context->GetInput(kInputUseStrictBoundaryMode)); - const bool input_use_strict_boundary_mode_value = - input_use_strict_boundary_mode->template AsScalar(); - - // Check token begin and end offsets match in size. - // Check span begin/end offsets, span type match in size. - if (input_token_begin_offsets_vec.Dim(0) != - input_token_end_offsets_vec.Dim(0) || - input_span_begin_offsets_vec.Dim(0) != - input_span_end_offsets_vec.Dim(0) || - input_span_begin_offsets_vec.Dim(0) != input_span_type_vec.Dim(0)) { - return absl::InvalidArgumentError(absl::StrCat( - "Token begin/end offsets must have the same size. Span begin/end " - "offsets and span type must have the same size.", - " Token begin offsets shape: ", input_token_begin_offsets_vec.Dim(0), - " Token end offsets shape: ", input_token_end_offsets_vec.Dim(0), - " Span begin offsets shape: ", input_span_begin_offsets_vec.Dim(0), - " Span end offsets shape: ", input_span_end_offsets_vec.Dim(0), - " Span type shape: ", input_span_type_vec.Dim(0))); - } - - // Check row splits are the same for token begin, end offsets. 
- if (input_token_begin_row_splits_vec.Dim(0) != - input_token_end_row_splits_vec.Dim(0) || - input_span_begin_row_splits_vec.Dim(0) != - input_span_begin_row_splits_vec.Dim(0) || - input_span_begin_row_splits_vec.Dim(0) != - input_span_end_row_splits_vec.Dim(0) || - input_span_begin_row_splits_vec.Dim(0) != - input_span_type_row_splits_vec.Dim(0)) { - return absl::InvalidArgumentError(absl::StrCat( - "Row splits must have the same size for token and span. ", - " Token begin row splits shape: ", - input_token_begin_row_splits_vec.Dim(0), - " Token end row splits shape: ", input_token_end_row_splits_vec.Dim(0), - " Span begin row splits shape: ", - input_span_begin_row_splits_vec.Dim(0), " Span end row splits shape: ", - input_span_end_row_splits_vec.Dim(0), " Span type row splits shape: ", - input_span_type_row_splits_vec.Dim(0))); - } - - for (int i = 0; i < input_token_begin_row_splits_vec.Dim(0) - 1; ++i) { - if (input_token_begin_row_splits_vec(i) != - input_token_end_row_splits_vec(i)) { - return absl::InvalidArgumentError( - "Row splits must be the same for token begin and end offsets."); - } - } - - // Check row splits are the same for span begin, end offsets and span type. - for (int i = 0; i < input_span_begin_row_splits_vec.Dim(0) - 1; ++i) { - if (input_span_begin_row_splits_vec(i) != - input_span_end_row_splits_vec(i) || - input_span_begin_row_splits_vec(i) != - input_span_type_row_splits_vec(i)) { - return absl::InvalidArgumentError( - "Row splits must be the same for span begin, end offsets and span " - "type."); - } - } - - // Outputs - std::vector boise_tags; - std::vector input_token_begin_offsets_vec_i; - std::vector input_token_end_offsets_vec_i; - std::vector input_span_begin_offsets_vec_i; - std::vector input_span_end_offsets_vec_i; - std::vector input_span_type_vec_i; - - // Iterate through all the input values and split them. 
- for (int i = 0; i < input_token_begin_row_splits_vec.Dim(0) - 1; ++i) { - int token_start_index = input_token_begin_row_splits_vec(i); - int token_end_index = input_token_begin_row_splits_vec(i + 1); - int span_start_index = input_span_begin_row_splits_vec(i); - int span_end_index = input_span_begin_row_splits_vec(i + 1); - - input_token_begin_offsets_vec_i.clear(); - input_token_end_offsets_vec_i.clear(); - input_span_begin_offsets_vec_i.clear(); - input_span_end_offsets_vec_i.clear(); - input_span_type_vec_i.clear(); - - for (int j = token_start_index; j < token_end_index; ++j) { - input_token_begin_offsets_vec_i.push_back( - input_token_begin_offsets_vec(j)); - input_token_end_offsets_vec_i.push_back(input_token_end_offsets_vec(j)); - } - for (int j = span_start_index; j < span_end_index; ++j) { - input_span_begin_offsets_vec_i.push_back(input_span_begin_offsets_vec(j)); - input_span_end_offsets_vec_i.push_back(input_span_end_offsets_vec(j)); - input_span_type_vec_i.push_back(input_span_type_vec(j)); - } - - SH_ASSIGN_OR_RETURN( - std::vector boise_tags_i, - OffsetsToBoiseTags( - input_token_begin_offsets_vec_i, input_token_end_offsets_vec_i, - input_span_begin_offsets_vec_i, input_span_end_offsets_vec_i, - input_span_type_vec_i, input_use_strict_boundary_mode_value)); - - for (int j = 0; j < boise_tags_i.size(); ++j) { - boise_tags.push_back(boise_tags_i[j]); - } - } - - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR(this->template FillOutputTensor( - boise_tags, kOutputBoiseTags, context)); - - return absl::OkStatus(); -} - - -template -class BoiseTagsToOffsetsOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kInputTokenBeginOffsets = 0, - kInputTokenEndOffsets, - kInputBoiseTags, - kInputTokenBeginRowSplits, - kInputTokenEndRowSplits, - kInputBoiseTagsRowSplits, - }; - enum Outputs { - kOutputSpanBeginOffsets = 0, - kOutputSpanEndOffsets, - kOutputSpanType, - kOutputRowSplits, - }; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - BoiseTagsToOffsetsOp() = default; - static constexpr char kOpName[] = "TFText>BoiseTagsToOffsets"; - static constexpr char kDoc[] = R"doc( - Converts BOISE tags into span begin/end offsets and span type. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); - - protected: - template - inline absl::Status FillOutputTensor(const std::vector& buffer, - int index, InvokeContext* context); -}; - -////////////////////////// Implementation - -template -std::vector BoiseTagsToOffsetsOp::Inputs() { - return {"input_token_begin_offsets: int32", - "input_token_end_offsets: int32", - "input_boise_tags: string", - 
"input_token_begin_row_splits: int64", - "input_token_end_row_splits: int64", - "input_boise_tags_row_splits: int64"}; -} - -template -std::vector BoiseTagsToOffsetsOp::Outputs() { - return {"output_span_begin_offsets: int32", "output_span_end_offsets: int32", - "output_span_type: string", "output_row_splits: int64"}; -} - -template -absl::Status BoiseTagsToOffsetsOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - - SH_ASSIGN_OR_RETURN(const Shape input_token_begin_shape, - c->GetInputShape(kInputTokenBeginOffsets)); - if (!input_token_begin_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_begin_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_end_shape, - c->GetInputShape(kInputTokenEndOffsets)); - if (!input_token_end_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_end_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_boise_tags_shape, - c->GetInputShape(kInputBoiseTags)); - if (!input_boise_tags_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_boise_tags_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_begin_rs_shape, - c->GetInputShape(kInputTokenBeginRowSplits)); - if (!input_token_begin_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_begin_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape input_token_end_rs_shape, - c->GetInputShape(kInputTokenEndRowSplits)); - if (!input_token_end_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_token_end_rs_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN(const Shape 
input_boise_tags_rs_shape, - c->GetInputShape(kInputBoiseTagsRowSplits)); - if (!input_boise_tags_rs_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_boise_tags_rs_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputSpanBeginOffsets, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputSpanEndOffsets, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputSpanType, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, rank_1_shape)); - - return absl::OkStatus(); -} - -template -absl::Status BoiseTagsToOffsetsOp::Invoke(InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto input_token_begin_offsets, - context->GetInput(kInputTokenBeginOffsets)); - const auto& input_token_begin_offsets_vec = - input_token_begin_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_end_offsets, - context->GetInput(kInputTokenEndOffsets)); - const auto& input_token_end_offsets_vec = - input_token_end_offsets->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_boise_tags, - context->GetInput(kInputBoiseTags)); - const auto& input_boise_tags_vec = - input_boise_tags->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_begin_row_splits, - context->GetInput(kInputTokenBeginRowSplits)); - const auto& input_token_begin_row_splits_vec = - input_token_begin_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_token_end_row_splits, - context->GetInput(kInputTokenEndRowSplits)); - const auto& input_token_end_row_splits_vec = - input_token_end_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_boise_tags_row_splits, - context->GetInput(kInputBoiseTagsRowSplits)); - const auto& input_boise_tags_row_splits_vec = - input_boise_tags_row_splits->template As(); - - // Check token begin and end offsets, and boise tags match in size. 
- if (input_token_begin_offsets_vec.Dim(0) != - input_token_end_offsets_vec.Dim(0) || - input_token_begin_offsets_vec.Dim(0) != input_boise_tags_vec.Dim(0)) { - return absl::InvalidArgumentError(absl::StrCat( - "Token begin/end offsets and boise tags must have the same size. ", - " Token begin offsets shape: ", input_token_begin_offsets_vec.Dim(0), - " Token end offsets shape: ", input_token_end_offsets_vec.Dim(0), - " BOISE tags shape: ", input_boise_tags_vec.Dim(0))); - } - - // Check row splits are the same for token begin, end offsets and boise tags. - // First, check dimensions are the same. - if (input_token_begin_row_splits_vec.Dim(0) != - input_token_end_row_splits_vec.Dim(0) || - input_token_begin_row_splits_vec.Dim(0) != - input_boise_tags_row_splits_vec.Dim(0)) { - return absl::InvalidArgumentError(absl::StrCat( - "Row splits must have the same size for token begin/end offsets and " - "BOISE tags. ", - " Token begin row splits shape: ", - input_token_begin_row_splits_vec.Dim(0), - " Token end row splits shape: ", input_token_end_row_splits_vec.Dim(0), - " BOISE tags row splits shape: ", - input_boise_tags_row_splits_vec.Dim(0))); - } - // Second, check values are the same. - for (int i = 0; i < input_token_begin_row_splits_vec.Dim(0) - 1; ++i) { - if (input_token_begin_row_splits_vec(i) != - input_token_end_row_splits_vec(i) || - input_token_begin_row_splits_vec(i) != - input_boise_tags_row_splits_vec(i)) { - return absl::InvalidArgumentError( - "Row splits must be the same for token begin/end offsets ad BOISE " - "tags."); - } - } - - // Outputs - std::vector span_begin_offsets; - std::vector span_end_offsets; - std::vector span_type; - std::vector row_splits; - - row_splits.push_back(0); - - // Iterate through all the input values and split them. 
- std::vector input_token_begin_offsets_vec_i; - std::vector input_token_end_offsets_vec_i; - std::vector input_boise_tags_vec_i; - for (int i = 0; i < input_token_begin_row_splits_vec.Dim(0) - 1; ++i) { - int token_start_index = input_token_begin_row_splits_vec(i); - int token_end_index = input_token_begin_row_splits_vec(i + 1); - - input_token_begin_offsets_vec_i.clear(); - input_token_end_offsets_vec_i.clear(); - input_boise_tags_vec_i.clear(); - - for (int j = token_start_index; j < token_end_index; ++j) { - input_token_begin_offsets_vec_i.push_back( - input_token_begin_offsets_vec(j)); - input_token_end_offsets_vec_i.push_back(input_token_end_offsets_vec(j)); - input_boise_tags_vec_i.push_back(input_boise_tags_vec(j)); - } - - auto [span_begin_offsets_i, span_end_offsets_i, span_type_i] = - BoiseTagsToOffsets(input_token_begin_offsets_vec_i, - input_token_end_offsets_vec_i, - input_boise_tags_vec_i) - .value(); - - const int num_span_i = span_type_i.size(); - row_splits.push_back(row_splits.back() + num_span_i); - - for (int j = 0; j < span_type_i.size(); ++j) { - span_type.push_back(span_type_i[j]); - span_begin_offsets.push_back(span_begin_offsets_i[j]); - span_end_offsets.push_back(span_end_offsets_i[j]); - } - } - - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR(FillOutputTensor( - span_begin_offsets, kOutputSpanBeginOffsets, context)); - SH_RETURN_IF_ERROR(FillOutputTensor( - span_end_offsets, kOutputSpanEndOffsets, context)); - SH_RETURN_IF_ERROR(FillOutputTensor( - span_type, kOutputSpanType, context)); - SH_RETURN_IF_ERROR(FillOutputTensor( - row_splits, kOutputRowSplits, context)); - - return absl::OkStatus(); -} - -template -template -absl::Status BoiseTagsToOffsetsOp::FillOutputTensor( - const std::vector& buffer, const int index, - InvokeContext* context) { - SH_ASSIGN_OR_RETURN( - const auto tensorview, - context->GetOutput( - index, tflite::shim::Shape({static_cast(buffer.size())}))); - auto data = tensorview->template As(); - // TODO(broken): investigate using memcpy like previous WST - for (int i = 0; i < buffer.size(); ++i) data(i) = buffer.at(i); - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/boise_offset_converter_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/boise_offset_converter_test.cc b/tensorflow_text/core/kernels/boise_offset_converter_test.cc deleted file mode 100644 index 06c279d06..000000000 --- a/tensorflow_text/core/kernels/boise_offset_converter_test.cc +++ /dev/null @@ -1,561 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/boise_offset_converter.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include "absl/strings/string_view.h" - -using ::testing::ContainerEq; - -namespace tensorflow { -namespace text { -namespace { - -// Helper function to extract texts based on the begin and end offsets. -// content = "Who let the dogs out" -// begin_offsets = {12, 17} -// end_offsets = {16, 20} -// Foo returns: {"dogs", "out"} -std::vector ExtractTextsFromOffsets( - const std::string content, const std::vector begin_offsets, - const std::vector end_offsets) { - absl::string_view content_sv = absl::string_view(content); - std::vector res; - for (int i = 0; i < begin_offsets.size(); ++i) { - int text_len = end_offsets[i] - begin_offsets[i]; - res.push_back(static_cast( - content_sv.substr(begin_offsets[i], text_len))); - } - return res; -} - -// Test that we can transform offsets into BOISE tags -TEST(OffsetsToBoiseTagsTest, ExtractSingleton) { - // 1 2 - // 012345678901234567890 - std::string content = "Who let the dogs out"; - std::string entity = "dogs"; - std::vector token_begin_offsets = {0, 4, 8, 12, 17}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector entity_begin_offsets = {12}; - std::vector entity_end_offsets = {16}; - std::vector entity_type = {"animal"}; - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "O", "O", "O", "S-animal", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, ExtractSingletonStrictBoundary) { - // 1 - // 01234567890123456789 - std::string content = "Who let the dogs out"; - std::string entity = "dogs"; - std::vector token_begin_offsets = {0, 4, 8, 12, 17}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector entity_begin_offsets = {13}; - std::vector entity_end_offsets = {16}; - std::vector 
entity_type = {"animal"}; - bool use_strict_boundary_mode = true; - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type, - use_strict_boundary_mode) - .ValueOrDie(); - EXPECT_THAT(boise_tags, - ContainerEq(std::vector{"O", "O", "O", "O", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, ExtractBEEntity) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::string entity = "german shepherd"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector entity_begin_offsets = {12}; - std::vector entity_end_offsets = {27}; - std::vector entity_type = {"animal"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "O", "O", "O", "B-animal", "E-animal", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, ExtractBIEEntity) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "How big is Los Angeles County?"; - std::string entity = "Los Angeles County"; - std::vector token_begin_offsets = {0, 4, 8, 11, 15, 23, 29}; - std::vector token_end_offsets = {3, 7, 10, 14, 22, 29, 30}; - std::vector entity_begin_offsets = {11}; - std::vector entity_end_offsets = {29}; - std::vector entity_type = {"loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "O", "O", "O", "B-loc", "I-loc", "E-loc", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, ExtractMutipleEntities) { - // 1 2 3 - // 01234567890123456789012345678901234567 - std::string content = "Getty Center is in Los Angeles County"; - std::vector token_begin_offsets = {0, 6, 13, 16, 19, 
23, 31}; - std::vector token_end_offsets = {5, 12, 15, 18, 22, 30, 37}; - std::vector entity_begin_offsets = {0, 19}; - std::vector entity_end_offsets = {12, 37}; - std::vector entity_type = {"org", "loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, - ContainerEq(std::vector{"B-org", "E-org", "O", "O", - "B-loc", "I-loc", "E-loc"})); -} - -TEST(OffsetsToBoiseTagsTest, LooseBoundary) { - // 1 2 3 - // 01234567890123456789012345678901234567 - std::string content = "Getty Center is in Los Angeles County"; - std::vector token_begin_offsets = {0, 6, 13, 16, 19, 23, 31}; - std::vector token_end_offsets = {5, 12, 15, 18, 22, 30, 37}; - std::vector entity_begin_offsets = {3, 19}; - std::vector entity_end_offsets = {10, 32}; - std::vector entity_type = {"org", "loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, - ContainerEq(std::vector{"B-org", "E-org", "O", "O", - "B-loc", "I-loc", "E-loc"})); -} - -TEST(OffsetsToBoiseTagsTest, StrictBoundary) { - // 1 2 3 - // 01234567890123456789012345678901234567 - std::string content = "Getty Center is in Los Angeles County"; - std::vector token_begin_offsets = {0, 6, 13, 16, 19, 23, 31}; - std::vector token_end_offsets = {5, 12, 15, 18, 22, 30, 37}; - std::vector entity_begin_offsets = {3, 19}; - std::vector entity_end_offsets = {12, 32}; - std::vector entity_type = {"org", "loc"}; - bool use_strict_boundary_mode = true; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type, - use_strict_boundary_mode) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "O", "E-org", "O", "O", "B-loc", "I-loc", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, 
OneTokenMultiEntitiesLastPrecedes) { - // 1 - // 0123456789012 - std::string content = "Getty Center"; - std::vector token_begin_offsets = {0}; - std::vector token_end_offsets = {12}; - std::vector entity_begin_offsets = {0, 6}; - std::vector entity_end_offsets = {5, 12}; - std::vector entity_type = {"per", "loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{"B-loc"})); -} - -TEST(OffsetsToBoiseTagsTest, OneTokenMultEntitiesPartialOverlapLastPrecedes) { - // 1 - // 0123456789012 - std::string content = "Getty Center"; - std::vector token_begin_offsets = {0, 6}; - std::vector token_end_offsets = {5, 12}; - std::vector entity_begin_offsets = {0, 9}; - std::vector entity_end_offsets = {8, 12}; - std::vector entity_type = {"per", "loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, - ContainerEq(std::vector{"B-per", "B-loc"})); -} - -TEST(OffsetsToBoiseTagsTest, MultiTokensOneEntityPartialOverlapLastPrecedes) { - // 1 2 3 - // 01234567890123456789012345678901234 - std::string content = "Getty Center, Los Angeles County"; - std::vector token_begin_offsets = {0, 6, 14, 18, 26}; - std::vector token_end_offsets = {5, 12, 17, 25, 32}; - std::vector entity_begin_offsets = {0, 15}; - std::vector entity_end_offsets = {14, 30}; - std::vector entity_type = {"org", "loc"}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "B-org", "I-org", "B-loc", "I-loc", "E-loc"})); -} - -TEST(OffsetsToBoiseTagsTest, EmptySpanOffsets) { - std::vector token_begin_offsets = {0, 6, 13, 16, 19, 23, 31}; - std::vector 
token_end_offsets = {5, 12, 15, 18, 22, 30, 37}; - std::vector entity_begin_offsets = {}; - std::vector entity_end_offsets = {}; - std::vector entity_type = {}; - - std::vector boise_tags = - OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, entity_type) - .ValueOrDie(); - EXPECT_THAT(boise_tags, ContainerEq(std::vector{ - "O", "O", "O", "O", "O", "O", "O"})); -} - -TEST(OffsetsToBoiseTagsTest, InputSizeError) { - std::vector token_begin_offsets = {0, 4, 8, 12, 17}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector entity_begin_offsets = {12}; - std::vector entity_end_offsets = {16}; - std::vector entity_type = {"animal", "extra_entity"}; - EXPECT_FALSE(OffsetsToBoiseTags(token_begin_offsets, token_end_offsets, - entity_begin_offsets, entity_end_offsets, - entity_type) - .ok()); -} - -// Test that BOISE tags can be transformed into offets -TEST(BoiseTagsToOffsetTest, BeginAndEndTagsAreConvertedToOffsets) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "B-animal", "E-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, SingletonTagsAreExtracted) { - // 1 2 - // 012345678901234567890 - std::string content = "Who let the dogs out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 17}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector boise_tags = {"O", "O", "O", "S-animal", "O"}; - - auto [begin_offsets, end_offsets, 
span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"dogs"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, BeginInsideAndEndLabelsAreExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "How big is Los Angeles County?"; - std::vector token_begin_offsets = {0, 4, 8, 11, 15, 23, 29}; - std::vector token_end_offsets = {3, 7, 10, 14, 22, 29, 30}; - std::vector boise_tags = {"O", "O", "O", "B-loc", - "I-loc", "E-loc", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, - ContainerEq(std::vector{"Los Angeles County"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"loc"})); -} - -TEST(BoiseTagsToOffsetTest, InsideEndLabelsAreExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "I-animal", "E-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, BeginInsideLabelsAreExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector 
token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "B-animal", "I-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, InsideOnlyLabelIsExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 21}; - std::vector token_end_offsets = {3, 7, 11, 20, 24}; - std::vector boise_tags = { - "O", "O", "O", "I-animal", "O", - }; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, BeginOnlyLabelIsExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 21}; - std::vector token_end_offsets = {3, 7, 11, 20, 24}; - std::vector boise_tags = { - "O", "O", "O", "B-animal", "O", - }; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, EndOnlyLabelIsExtracted) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the shepherd out"; 
- std::vector token_begin_offsets = {0, 4, 8, 12, 21}; - std::vector token_end_offsets = {3, 7, 11, 20, 24}; - std::vector boise_tags = { - "O", "O", "O", "E-animal", "O", - }; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, MultipleEntitiesAreExtracted) { - // 1 2 3 - // 01234567890123456789012345678901234567 - std::string content = "Getty Center is in Los Angeles County"; - std::vector token_begin_offsets = {0, 6, 13, 16, 19, 23, 31}; - std::vector token_end_offsets = {5, 12, 15, 18, 22, 30, 37}; - std::vector boise_tags = {"B-org", "E-org", "O", "O", - "B-loc", "I-loc", "E-loc"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{ - "Getty Center", "Los Angeles County"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"org", "loc"})); -} - -TEST(BoiseTagsToOffsetTest, MultipleBeginLabels) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "B-loc", "B-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, 
ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, MultipleInsideLabels) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "I-loc", "I-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, MultipleEndLabels) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "E-loc", "E-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, - ContainerEq(std::vector{"german", "shepherd"})); - EXPECT_THAT(span_types, - ContainerEq(std::vector{"loc", "animal"})); -} - -TEST(BoiseTagsToOffsetTest, MultipleSingleLabels) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who let the german shepherd out"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19, 28}; - std::vector token_end_offsets = {3, 7, 11, 18, 27, 31}; - std::vector boise_tags = {"O", "O", "O", - "S-loc", "S-animal", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = 
ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, - ContainerEq(std::vector{"german", "shepherd"})); - EXPECT_THAT(span_types, - ContainerEq(std::vector{"loc", "animal"})); -} - -TEST(BoiseTagsToOffsetTest, TrailingBeginLabels) { - // 1 2 3 - // 0123456789012345678901234567890 - std::string content = "Who own the german shepherd"; - std::vector token_begin_offsets = {0, 4, 8, 12, 19}; - std::vector token_end_offsets = {3, 7, 11, 18, 27}; - std::vector boise_tags = {"O", "O", "O", "B-loc", "B-animal"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - auto texts = ExtractTextsFromOffsets(content, begin_offsets, end_offsets); - EXPECT_THAT(texts, ContainerEq(std::vector{"german shepherd"})); - EXPECT_THAT(span_types, ContainerEq(std::vector{"animal"})); -} - -TEST(BoiseTagsToOffsetTest, NoBoiseLabels) { - std::vector token_begin_offsets = {0, 4, 8, 12, 19}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector boise_tags = {"O", "O", "O", "O", "O"}; - - auto [begin_offsets, end_offsets, span_types] = - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ValueOrDie(); - - EXPECT_TRUE(begin_offsets.empty()); - EXPECT_TRUE(end_offsets.empty()); - EXPECT_TRUE(span_types.empty()); -} - -TEST(BoiseTagsToOffsetTest, InputSizeError) { - std::vector token_begin_offsets = {0, 4, 8, 12}; - std::vector token_end_offsets = {3, 7, 11, 16, 20}; - std::vector boise_tags = {"O", "O", "O", "B-loc", "B-animal"}; - EXPECT_FALSE( - BoiseTagsToOffsets(token_begin_offsets, token_end_offsets, boise_tags) - .ok()); -} - -TEST(GetAllBoiseTagsFromSpanTypeTest, GetAllTagsCorrect) { - std::vector span_type = {"loc", "O", "per", ""}; - std::unordered_set all_tags = - GetAllBoiseTagsFromSpanType(span_type); - EXPECT_THAT(all_tags, ContainerEq(std::unordered_set{ - "O", "B-loc", "I-loc", "S-loc", "E-loc", "B-per", - "I-per", "S-per", 
"E-per"})); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/byte_splitter.cc b/tensorflow_text/core/kernels/byte_splitter.cc deleted file mode 100644 index df7201717..000000000 --- a/tensorflow_text/core/kernels/byte_splitter.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/byte_splitter.h" - -#include - -namespace tensorflow { -namespace text { - -void ByteSplitter::Split(const absl::string_view input, - std::vector* bytes, - std::vector* start_offsets, - std::vector* end_offsets) const { - if (input.empty()) return; - Split(input, bytes); - start_offsets->push_back(0); - for (int i = 1; i < input.size(); ++i) { - start_offsets->push_back(i); - end_offsets->push_back(i); - } - end_offsets->push_back(input.size()); -} - -void ByteSplitter::Split(const absl::string_view input, - std::vector* bytes, - std::vector* offsets) const { - if (input.empty()) return; - Split(input, bytes); - for (int i = 0; i <= input.size(); ++i) { - offsets->push_back(i); - } -} - -void ByteSplitter::Split(const absl::string_view input, - std::vector* bytes) const { - for (const auto& c : input) { - bytes->push_back(c); - } -} - -absl::StatusOr> ByteSplitter::SplitByOffsets( - absl::string_view input, - absl::Span start_offsets, - absl::Span end_offsets) const { - std::vector result; - int num = 
std::min(start_offsets.size(), end_offsets.size()); - for (int i = 0; i < num; ++i) { - if (start_offsets[i] < 0 || start_offsets[i] > input.size()) { - return absl::InvalidArgumentError("Start offsets out of range."); - } - if (end_offsets[i] < 0 || end_offsets[i] > input.size()) { - return absl::InvalidArgumentError("End offsets out of range."); - } - if (start_offsets[i] > end_offsets[i]) { - return absl::InvalidArgumentError("Start offset after end offset."); - } - result.push_back(input.substr(start_offsets[i], - end_offsets[i] - start_offsets[i])); - } - return result; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/byte_splitter.h b/tensorflow_text/core/kernels/byte_splitter.h index 954b7ccfd..46582ac2a 100644 --- a/tensorflow_text/core/kernels/byte_splitter.h +++ b/tensorflow_text/core/kernels/byte_splitter.h @@ -12,95 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_TOKENIZER_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_TOKENIZER_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_H_ -#include -#include +#include "tensorflow/core/kernels/text/byte_splitter.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace text { - -class ByteSplitter { - public: - // Creates an instance. - ByteSplitter() { } - - // Tokenizes a string into bytes. - // - // Example: - // input = "uñ" - // bytes = [117, 195, 177] - // start_offsets = [0, 1, 2] - // end_offsets = [1, 2, 3] - // - // Args: - // * input: The string of an input. - // * bytes: The output bytes. - // * start_offsets: The start offsets of output bytes in the input text. - // * end_offsets: The end offsets of output bytes in the input text. 
- // Note: the start offsets are inclusive and the end offsets are exclusive. - void Split(const absl::string_view input, - std::vector* bytes, - std::vector* start_offsets, - std::vector* end_offsets) const; - - // Tokenizes a string into bytes. - // - // Example: - // input = "uñ" - // bytes = [117, 195, 177] - // offsets = [0, 1, 2, 3] - // - // Args: - // * input: The string of an input. - // * bytes: The output bytes. - // * offsets: The offsets of output bytes in the input text. The size will - // be one plus the input. Each value is the mapped offset of each byte of - // the original input text. The final value maps the end. - // Note: the start offsets are inclusive and the end offsets are exclusive. - void Split(const absl::string_view input, - std::vector* bytes, - std::vector* offsets) const; - - // Tokenizes a string into bytes. - // - // Example: - // input = "uñ" - // bytes = [117, 195, 177] - // - // Args: - // * input: The string of an input. - // * bytes: The output bytes. - void Split(const absl::string_view input, - std::vector* bytes) const; - - // Splits a string by the given start and end offsets. - // - // Example: - // input = "uñ" - // start_offsets = [0, 1] - // end_offsets = [1, 3] - // string = ["u", "ñ"] - // - // Args: - // * input: The string of an input. - // * start_offsets: Input byte index where the new strings start (inclusive). - // * end_offsets: Input byte index where the new strings end. (exclusive) - // - // Return: - // The split substrings. 
- absl::StatusOr> SplitByOffsets( - absl::string_view input, - absl::Span start_offsets, - absl::Span end_offsets) const; -}; - -} // namespace text -} // namespace tensorflow - - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_TOKENIZER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_H_ diff --git a/tensorflow_text/core/kernels/byte_splitter_kernel.cc b/tensorflow_text/core/kernels/byte_splitter_kernel.cc deleted file mode 100644 index b05d7d29c..000000000 --- a/tensorflow_text/core/kernels/byte_splitter_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/byte_splitter_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER(Name(ByteSplitterWithOffsetsOpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - ByteSplitterWithOffsetsOpKernel); - -REGISTER_KERNEL_BUILDER(Name(ByteSplitByOffsetsOpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - ByteSplitByOffsetsOpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/byte_splitter_kernel.h b/tensorflow_text/core/kernels/byte_splitter_kernel.h index c3e3df413..10e999f51 100644 --- a/tensorflow_text/core/kernels/byte_splitter_kernel.h +++ b/tensorflow_text/core/kernels/byte_splitter_kernel.h @@ -15,25 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/byte_splitter_kernel_template.h" - -namespace tensorflow { -namespace text { - -class ByteSplitterWithOffsetsOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -class ByteSplitByOffsetsOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/byte_splitter_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/byte_splitter_kernel_template.h b/tensorflow_text/core/kernels/byte_splitter_kernel_template.h index 77ab2b1ba..c61f6e7f9 100644 --- a/tensorflow_text/core/kernels/byte_splitter_kernel_template.h +++ b/tensorflow_text/core/kernels/byte_splitter_kernel_template.h @@ -15,299 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_ #define 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_ -#include -#include - -#include "absl/status/status.h" -#include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/byte_splitter.h" - -namespace tensorflow { -namespace text { - -template -class ByteSplitterWithOffsetsOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kInputValues = 0 - }; - enum Outputs { - kOutputBytes = 0, - kOutputRowSplits, - kOutputStartOffsets, - kOutputEndOffsets - }; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - ByteSplitterWithOffsetsOp() = default; - static constexpr char kOpName[] = "TFText>ByteSplitWithOffsets"; - static constexpr char kDoc[] = R"doc( - Splits a string into bytes - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template -std::vector ByteSplitterWithOffsetsOp::Inputs() { - return {"input_values: string"}; -} - -template -std::vector ByteSplitterWithOffsetsOp::Outputs() { - return {"output_bytes: uint8", "output_row_splits: int64", - 
"output_start_offsets: int32", "output_end_offsets: int32"}; -} - -template -absl::Status ByteSplitterWithOffsetsOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - - SH_ASSIGN_OR_RETURN(const Shape& input_values_shape, - c->GetInputShape(kInputValues)); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Input values shape must be rank 1: ", - input_values_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputBytes, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputStartOffsets, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputEndOffsets, rank_1_shape)); - const int num_splits = Shape::AddDims(1, input_values_shape.Dim(0)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, Shape({num_splits}))); - - return absl::OkStatus(); -} - -template - absl::Status ByteSplitterWithOffsetsOp - ::Invoke(InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto values_view, context->GetInput(kInputValues)); - const auto values = values_view->template As(); - - ByteSplitter splitter; - - // Outputs - std::vector bytes; - std::vector row_splits; - std::vector start_offsets; - std::vector end_offsets; - - // Iterate through all the string values and split them. - row_splits.push_back(0); - for (int i = 0; i < values.Dim(0); ++i) { - // Split into bytes and record the offset locations. - const int orig_num_bytes = bytes.size(); - splitter.Split(values(i), &bytes, &start_offsets, &end_offsets); - const int delta_num_bytes = bytes.size() - orig_num_bytes; - // Record the row splits. - row_splits.push_back(delta_num_bytes + row_splits.back()); - } - - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR(this->template FillOutputTensor( - bytes, kOutputBytes, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - row_splits, kOutputRowSplits, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - start_offsets, kOutputStartOffsets, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - end_offsets, kOutputEndOffsets, context)); - - return absl::OkStatus(); -} - - -template -class ByteSplitByOffsetsOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kInputValues = 0, - kInputStartOffsets, - kInputEndOffsets, - kInputRowSplits - }; - enum Outputs { - kOutputValues = 0, - kOutputRowSplits, - }; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - ByteSplitByOffsetsOp() = default; - static constexpr char kOpName[] = "TFText>ByteSplitByOffsets"; - static constexpr char kDoc[] = R"doc( - Splits a string into bytes using the given start and end offsets. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template -std::vector ByteSplitByOffsetsOp::Inputs() { - return {"input_values: string", "input_start_offsets: int32", - "input_end_offsets: int32", "input_row_splits: int64"}; -} - -template -std::vector ByteSplitByOffsetsOp::Outputs() { - return {"output_values: string", "output_row_splits: int64"}; -} - -template -absl::Status ByteSplitByOffsetsOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - // input values shape - SH_ASSIGN_OR_RETURN(const Shape& input_values_shape, - c->GetInputShape(kInputValues)); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Input values shape must be rank 1: ", - input_values_shape.ToString())); - } - // input starts shape - SH_ASSIGN_OR_RETURN(const Shape& input_starts_shape, - c->GetInputShape(kInputStartOffsets)); - if (!input_starts_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Input start offsets shape must be rank 1: ", - input_starts_shape.ToString())); - } - // input ends shape - SH_ASSIGN_OR_RETURN(const Shape& input_ends_shape, - c->GetInputShape(kInputEndOffsets)); - if (!input_ends_shape.Compatible(rank_1_shape)) { - return 
absl::FailedPreconditionError(absl::StrCat( - "Input end offsets shape must be rank 1: ", - input_ends_shape.ToString())); - } - // input row splits shape - SH_ASSIGN_OR_RETURN(const Shape& input_row_splits_shape, - c->GetInputShape(kInputRowSplits)); - if (!input_row_splits_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Input row splits shape must be rank 1: ", - input_row_splits_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputValues, input_starts_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, - input_row_splits_shape)); - - return absl::OkStatus(); -} - -template - absl::Status ByteSplitByOffsetsOp - ::Invoke(InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto input_values_view, - context->GetInput(kInputValues)); - const auto input_values = - input_values_view->template As(); - SH_ASSIGN_OR_RETURN(const auto starts_view, - context->GetInput(kInputStartOffsets)); - const auto starts = starts_view->template As(); - SH_ASSIGN_OR_RETURN(const auto ends_view, - context->GetInput(kInputEndOffsets)); - const auto ends = ends_view->template As(); - SH_ASSIGN_OR_RETURN(const auto in_splits_view, - context->GetInput(kInputRowSplits)); - const auto in_splits = in_splits_view->template As(); - - ByteSplitter splitter; - - // Outputs - std::vector output_values; - std::vector out_splits; - - // Iterate through all the string values and split them. - out_splits.push_back(0); - for (int i = 0; i < input_values.Dim(0); ++i) { - SH_ASSIGN_OR_RETURN(auto batch, - splitter.SplitByOffsets( - input_values(i), - absl::MakeSpan(starts.Ptr() + in_splits(i), - in_splits(i+1) - in_splits(i)), - absl::MakeSpan(ends.Ptr() + in_splits(i), - in_splits(i+1) - in_splits(i)))); - output_values.insert(output_values.end(), batch.begin(), batch.end()); - out_splits.push_back(batch.size() + out_splits.back()); - } - - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR( - this->template FillOutputTensor( - output_values, kOutputValues, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - out_splits, kOutputRowSplits, context)); - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/byte_splitter_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/byte_splitter_test.cc b/tensorflow_text/core/kernels/byte_splitter_test.cc deleted file mode 100644 index c6b4d6e72..000000000 --- a/tensorflow_text/core/kernels/byte_splitter_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/byte_splitter.h" - -#include - -#include - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::ElementsAre; - -TEST(ByteSplitterTest, SplitAscii) { - const absl::string_view input_string("hello"); - std::vector output_bytes; - std::vector output_offsets; - ByteSplitter s; - s.Split(input_string, &output_bytes, &output_offsets); - EXPECT_THAT(output_bytes, ElementsAre(104, 101, 108, 108, 111)); - EXPECT_THAT(output_offsets, ElementsAre(0, 1, 2, 3, 4, 5)); -} - -TEST(ByteSplitterTest, SplitUnicode) { - const absl::string_view input_string("muñdʓ"); - std::vector output_bytes; - std::vector output_offsets; - ByteSplitter s; - s.Split(input_string, &output_bytes, &output_offsets); - EXPECT_THAT(output_bytes, ElementsAre(109, 117, 195, 177, 100, 202, 147)); - EXPECT_THAT(output_offsets, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7)); -} - -TEST(ByteSplitterTest, SplitEmoji) { - const absl::string_view input_string("😀🙃"); - std::vector output_bytes; - std::vector output_offsets; - ByteSplitter s; - s.Split(input_string, &output_bytes, &output_offsets); - EXPECT_THAT(output_bytes, - ElementsAre(240, 159, 152, 128, 240, 159, 153, 131)); - EXPECT_THAT(output_offsets, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8)); -} - -TEST(ByteSplitterTest, SplitHanzi) { - const absl::string_view input_string("你好"); - std::vector output_bytes; - std::vector output_offsets; - ByteSplitter s; - s.Split(input_string, &output_bytes, &output_offsets); - EXPECT_THAT(output_bytes, ElementsAre(228, 189, 160, 229, 165, 189)); - EXPECT_THAT(output_offsets, ElementsAre(0, 1, 2, 3, 4, 5, 6)); -} - -TEST(ByteSplitterTest, SplitByBytesHanzi) { - ByteSplitter s; - auto output = s.SplitByOffsets("你好", {0, 3}, {3, 6}); - EXPECT_TRUE(output.ok()); - EXPECT_THAT(output.value(), ElementsAre("你", "好")); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/byte_splitter_tflite.cc 
b/tensorflow_text/core/kernels/byte_splitter_tflite.cc deleted file mode 100644 index a733467dd..000000000 --- a/tensorflow_text/core/kernels/byte_splitter_tflite.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/byte_splitter_tflite.h" - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/byte_splitter_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddByteSplit(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel< - tensorflow::text::ByteSplitterWithOffsetsOp>::Add(resolver); -} - -extern "C" void AddByteSplitByOffsets(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel< - tensorflow::text::ByteSplitByOffsetsOp>::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/byte_splitter_tflite.h b/tensorflow_text/core/kernels/byte_splitter_tflite.h index c1219ecc7..73304a13c 100644 --- a/tensorflow_text/core/kernels/byte_splitter_tflite.h +++ b/tensorflow_text/core/kernels/byte_splitter_tflite.h @@ -15,21 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_TFLITE_H_ -#include 
"tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddByteSplit(::tflite::MutableOpResolver* resolver); - -extern "C" void AddByteSplitByOffsets(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/byte_splitter_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/constrained_sequence.cc b/tensorflow_text/core/kernels/constrained_sequence.cc deleted file mode 100644 index 2553472c1..000000000 --- a/tensorflow_text/core/kernels/constrained_sequence.cc +++ /dev/null @@ -1,441 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/constrained_sequence.h" - -#include -#include -#include -#include - -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace text { - -// State index to use if the sequence in question requires an impossible -// transition. 
-constexpr int kErrorState = -1; - -ScoreAccessor::ScoreAccessor(const Tensor &score_tensor, - const Tensor &lengths_tensor) { - data_ = score_tensor.flat().data(); - if (lengths_tensor.dtype() == DT_INT64) { - use_long_lengths_ = true; - long_lengths_ = lengths_tensor.flat().data(); - } else { - use_long_lengths_ = false; - lengths_ = lengths_tensor.flat().data(); - } - has_explicit_batch_ = (score_tensor.shape().dims() == 3); - if (has_explicit_batch_) { - batch_size_ = score_tensor.shape().dim_size(0); - num_steps_ = score_tensor.shape().dim_size(1); - num_scores_ = score_tensor.shape().dim_size(2); - } else { - batch_size_ = 1; - num_steps_ = score_tensor.shape().dim_size(0); - num_scores_ = score_tensor.shape().dim_size(1); - } - batch_offset_ = num_scores_ * num_steps_; - step_offset_ = num_scores_; -} - -// Get a score out of the data tensor. -float ScoreAccessor::GetScore(int batch_idx, int step_idx, - int score_idx) const { - DCHECK_LE(batch_idx, batch_size_); - DCHECK_LE(step_idx, num_steps_); - DCHECK_LE(score_idx, num_scores_); - return data_[batch_offset_ * batch_idx + step_offset_ * step_idx + score_idx]; -} - -int64 ScoreAccessor::GetLength(int batch_idx) const { - DCHECK_LE(batch_idx, batch_size_); - if (use_long_lengths_) { - return long_lengths_[batch_idx]; - } else { - return lengths_[batch_idx]; - } -} - -int ScoreAccessor::batch_size() const { return batch_size_; } -int ScoreAccessor::num_steps() const { return num_steps_; } -int ScoreAccessor::num_scores() const { return num_scores_; } -bool ScoreAccessor::has_explicit_batch() const { return has_explicit_batch_; } - -// Perform Viterbi analysis on a single batch item. 
-void ViterbiAnalysis( - const ScoreAccessor &scores, - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - const int batch, bool use_log_space, bool use_start_end_states, - int32 *output_data) { - VLOG(2) << "Analyzing batch " << batch; - const bool has_transition_weights = transition_weights.size() != 0; - const bool has_allowed_transitions = allowed_transitions.size() != 0; - const int num_states = scores.num_scores(); - const int out_of_bounds_index = num_states; - - int64 num_steps = scores.GetLength(batch); - - // Create two vectors to hold scores. These will be bound to referents later - // so the names here are somewhat irrelevant. - std::vector scores_a(num_states, - std::numeric_limits::lowest()); - std::vector scores_b(num_states, - std::numeric_limits::lowest()); - - // Create a chart of backpointers. Include rows for [start] and [end] - // transitions. By initializing this to kErrorState, we ensure unreachable - // transitions get marked as errors. - std::vector> backpointers( - num_steps, std::vector(num_states, kErrorState)); - - // Set current and previous references for step 0 - std::vector *previous_scores = &scores_a; - std::vector *current_scores = &scores_b; - - const bool vlog3 = VLOG_IS_ON(3); - - if (backpointers.empty()) { - // We're done with this batch if there are no steps to analyze. - return; - } - for (int curr_state = 0; curr_state < num_states; ++curr_state) { - std::vector ¤t_bps = backpointers[0]; - if (use_start_end_states) { - // Initialize the zeroth step BPs to kOutOfBoundsIndex for all states - // where the OOB->state transition is valid, and set scores as needed. 
- if (has_allowed_transitions && - !allowed_transitions(out_of_bounds_index, curr_state)) { - if (vlog3) { - LOG(INFO) << "(" << batch << ", 0, [START]->" << curr_state - << "): disallowed."; - } - continue; - } - - // Because the backpointer vectors are initialized to kErrorState, we - // need only to set the valid transition paths to have come from the - // padding state. - current_bps[curr_state] = out_of_bounds_index; - - // For valid transitions, get the score (and adjust as appropriate). - const int step = 0; - float current_score = scores.GetScore(batch, step, curr_state); - if (has_transition_weights) { - if (use_log_space) { - current_score += transition_weights(out_of_bounds_index, curr_state); - } else { - current_score *= transition_weights(out_of_bounds_index, curr_state); - } - } - - if (vlog3) { - if (has_transition_weights) { - LOG(INFO) << "(" << batch << ", " << step << ", [START]->" - << curr_state << "): Total score: " << current_score - << " (raw: " << scores.GetScore(batch, step, curr_state) - << ", tw: " - << transition_weights(out_of_bounds_index, curr_state) - << ")"; - } else { - LOG(INFO) << "(" << batch << ", " << step << ", [START]->" - << curr_state << "): Total score: " << current_score - << " (raw: " << scores.GetScore(batch, step, curr_state) - << ")"; - } - } - - current_scores->at(curr_state) = current_score; - } else { - // If we don't have specific start and end states, all bp's are valid - // and all starting scores are the unadjusted step 0 scores. - current_bps[curr_state] = out_of_bounds_index; - const int step = 0; - current_scores->at(curr_state) = scores.GetScore(batch, step, curr_state); - } - } - - // Update the current scores (and normalize if we're not in log space). 
- if (!use_log_space) { - const double max_score = - *std::max_element(current_scores->begin(), current_scores->end()); - if (max_score > 0) { - for (double &score : *current_scores) score /= max_score; - } - } - - // Swap current and previous score arrays, as we are advancing a step. - std::vector *tmp = previous_scores; - previous_scores = current_scores; - current_scores = tmp; - - // Handle all steps save for the first and last in this loop. - for (int step = 1; step < num_steps; ++step) { - const std::vector &previous_bps = backpointers[step - 1]; - std::vector ¤t_bps = backpointers[step]; - - for (int curr_state = 0; curr_state < num_states; ++curr_state) { - int best_source_state = kErrorState; - float best_score = std::numeric_limits::lowest(); - for (int prev_state = 0; prev_state < num_states; ++prev_state) { - // If the previous state was an error state, pass to the next state. - if (previous_bps[prev_state] == kErrorState) { - if (vlog3) { - LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state - << "->" << curr_state << "): prev state error."; - } - continue; - } - - // If this is not a permitted transition, continue. 
- if (has_allowed_transitions && - !allowed_transitions(prev_state, curr_state)) { - if (vlog3) { - LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state - << "->" << curr_state << "): disallowed."; - } - continue; - } - - float current_score = scores.GetScore(batch, step, curr_state); - if (use_log_space) { - current_score += previous_scores->at(prev_state); - } else { - current_score *= previous_scores->at(prev_state); - } - if (has_transition_weights) { - if (use_log_space) { - current_score += transition_weights(prev_state, curr_state); - } else { - current_score *= transition_weights(prev_state, curr_state); - } - } - - if (vlog3) { - if (has_transition_weights) { - LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state - << "->" << curr_state - << "): Total score: " << current_score - << " (prev: " << previous_scores->at(prev_state) - << ", raw: " << scores.GetScore(batch, step, curr_state) - << ", tw: " << transition_weights(prev_state, curr_state) - << ")"; - } else { - LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state - << "->" << curr_state - << "): Total score: " << current_score - << " (prev: " << previous_scores->at(prev_state) - << ", raw: " << scores.GetScore(batch, step, curr_state) - << ")"; - } - } - - if (current_score >= best_score) { - best_source_state = prev_state; - best_score = current_score; - } - } - current_bps[curr_state] = best_source_state; - current_scores->at(curr_state) = best_score; - } - - // Normalize if we're not in log space. - if (!use_log_space) { - const double max_score = - *std::max_element(current_scores->begin(), current_scores->end()); - if (max_score > 0) { - for (double &score : *current_scores) score /= max_score; - } - } - - // After each step, switch the current scores to the previous scores and - // use the previous previous scores as the current scores. 
- std::vector *tmp = previous_scores; - previous_scores = current_scores; - current_scores = tmp; - } - - // Handle the final transition out of the sequence. - int final_state = out_of_bounds_index; - const std::vector &previous_bps = backpointers[num_steps - 1]; - int best_source_state = kErrorState; - float final_score = std::numeric_limits::lowest(); - - for (int prev_state = 0; prev_state < num_states; ++prev_state) { - // If the previous state was an error state, pass to the next state. - if (previous_bps[prev_state] == kErrorState) { - current_scores->at(prev_state) = std::numeric_limits::lowest(); - if (vlog3) { - LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state - << "->[END]): prev state error."; - } - continue; - } - - // If this is not a permitted transition, continue. - if (has_allowed_transitions && use_start_end_states && - !allowed_transitions(prev_state, final_state)) { - current_scores->at(prev_state) = std::numeric_limits::lowest(); - if (vlog3) { - LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state - << "->[END]): disallowed."; - } - continue; - } - - // Weight the final transition score by the probability of exiting the - // sequence as well. 
- float current_score = previous_scores->at(prev_state); - if (use_start_end_states) { - if (has_transition_weights) { - if (use_log_space) { - current_score += transition_weights(prev_state, final_state); - } else { - current_score *= transition_weights(prev_state, final_state); - } - } - - if (vlog3) { - if (has_transition_weights) { - LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state - << "->[END]): Total score: " << current_score - << " (prev: " << previous_scores->at(prev_state) - << ", tw: " << transition_weights(prev_state, final_state) - << ")"; - } else { - LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state - << "->[END]): Total score: " << current_score - << " (prev: " << previous_scores->at(prev_state) << ")"; - } - } - } - - current_scores->at(prev_state) = current_score; - if (current_score >= final_score) { - best_source_state = prev_state; - final_score = current_score; - } - } - - if (vlog3) { - LOG(INFO) << "Final score: " << final_score; - } - - // Calculate the path. - if (best_source_state == kErrorState) { - // If the best source is an error state, the path is unknowable. Report - // error states for the whole sequence. - for (int64 i = 0; i < scores.GetLength(batch); ++i) { - output_data[i] = kErrorState; - } - } else { - // If the best source is a 'real' state, report the state path. 
- int steps_to_report = scores.GetLength(batch); - int previous_state = best_source_state; - for (int64 i = steps_to_report - 1; i >= 0; --i) { - output_data[i] = previous_state; - previous_state = backpointers[i][previous_state]; - } - } -} - -void GreedyAnalysis( - const ScoreAccessor &scores, - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - int batch, bool use_log_space, bool use_start_end_states, - int32 *output_data) { - const bool has_transition_weights = transition_weights.size() != 0; - const bool has_allowed_transitions = allowed_transitions.size() != 0; - const int num_states = scores.num_scores(); - const int out_of_bounds_index = num_states; - int64 num_steps = scores.GetLength(batch); - - for (int step = 0; step < num_steps; ++step) { - // Do final step calculations if this is the final step in the sequence - // and we are calculating based on implicit start and end states. - bool do_final_step = - (step == scores.GetLength(batch) - 1) && use_start_end_states; - VLOG(2) << "is last step: " << do_final_step; - - const int previous_state = - (step == 0) ? (out_of_bounds_index) : (output_data[step - 1]); - - if (previous_state == kErrorState) { - // If the previous state is the error state, the current state must - // also be the error state. - output_data[step] = kErrorState; - continue; - } - - // If no transition is possible, this will stay the error state. - int best_new_state = kErrorState; - float best_new_score = std::numeric_limits::lowest(); - - for (int state = 0; state < num_states; ++state) { - float current_score = scores.GetScore(batch, step, state); - - // If we are not using start/end states AND step is 0, then - // current_score will not be altered. 
- if (use_start_end_states || step > 0) { - if (has_allowed_transitions) { - // If either the transition from the previous state to this state - // is disallowed, or we need to analyze the final step and the - // transition from this state to the final step is not allowed, - // disallow this transition. - if (!allowed_transitions(previous_state, state) || - (do_final_step && - !allowed_transitions(state, out_of_bounds_index))) { - continue; - } - } - - if (has_transition_weights) { - if (use_log_space) { - current_score += transition_weights(previous_state, state); - } else { - current_score *= transition_weights(previous_state, state); - } - // On the last step, also analyze by the weight value of - // transitioning from this state to the out-of-bounds state. - if (do_final_step) { - if (use_log_space) { - current_score += transition_weights(state, out_of_bounds_index); - } else { - current_score *= transition_weights(state, out_of_bounds_index); - } - } - } - } - if (current_score >= best_new_score) { - best_new_state = state; - best_new_score = current_score; - } - } - output_data[step] = best_new_state; - VLOG(2) << "Best state for step " << step << " is " << output_data[step] - << " with score " << best_new_score; - } -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/constrained_sequence.h b/tensorflow_text/core/kernels/constrained_sequence.h index 5f62f46b6..120f0ea65 100644 --- a/tensorflow_text/core/kernels/constrained_sequence.h +++ b/tensorflow_text/core/kernels/constrained_sequence.h @@ -12,81 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/kernels/text/constrained_sequence.h" -namespace tensorflow { -namespace text { - -class ScoreAccessor { - public: - explicit ScoreAccessor(const Tensor &score_tensor, - const Tensor &lengths_tensor); - - // Get a score out of the data tensor. - float GetScore(int batch_idx, int step_idx, int score_idx) const; - - int64 GetLength(int batch_idx) const; - - int batch_size() const; - int num_steps() const; - int num_scores() const; - bool has_explicit_batch() const; - - private: - // A pointer into the underlying data of the score tensor. Not owned. - const float *data_; - - // A pointer into the underlying data of the lengths tensor. Not owned. - const int *lengths_; - const int64 *long_lengths_; - - // Whether the passed lengths tensor is int32 or int64. - bool use_long_lengths_; - - // The batch size associated with the data tensor. - int batch_size_; - - // The number of steps in the data tensor. - int num_steps_; - - // The number of scores in the data tensor. - int num_scores_; - - // The amount to increase the offset within the flat data array if the batch - // index increases by 1. - int batch_offset_; - - // The amount to increase the offset within the flat data array if the step - // index increases by 1. - int step_offset_; - - // True if the original tensor had an explicit batch dimension (that is, - // it was of rank 3). - bool has_explicit_batch_; -}; - -// Perform Viterbi analysis on a single batch item. 
-void ViterbiAnalysis( - const ScoreAccessor &scores, - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - const int batch, bool use_log_space, bool use_start_end_states, - int32 *output_data); - -// Perform a greedy analysis on a single batch item. -void GreedyAnalysis( - const ScoreAccessor &scores, - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - int batch, bool use_log_space, bool use_start_end_states, - int32 *output_data); - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_CONSTRAINED_SEQUENCE_H_ diff --git a/tensorflow_text/core/kernels/constrained_sequence_kernel.cc b/tensorflow_text/core/kernels/constrained_sequence_kernel.cc deleted file mode 100644 index 339c202ce..000000000 --- a/tensorflow_text/core/kernels/constrained_sequence_kernel.cc +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow_text/core/kernels/constrained_sequence.h" - -namespace tensorflow { - -using ::tensorflow::DataType; -using ::tensorflow::DEVICE_CPU; -using ::tensorflow::DT_BOOL; -using ::tensorflow::DT_FLOAT; -using ::tensorflow::OpKernel; -using ::tensorflow::OpKernelConstruction; -using ::tensorflow::OpKernelContext; -using ::tensorflow::Status; -using ::tensorflow::Tensor; -using ::tensorflow::TensorShape; -using ::tensorflow::errors::InvalidArgument; -using ::tensorflow::text::GreedyAnalysis; -using ::tensorflow::text::ScoreAccessor; -using ::tensorflow::text::ViterbiAnalysis; - -// State index to use if the sequence in question requires an impossible -// transition. -constexpr int kErrorState = -1; - -// State index to use when outputting a padded tensor and the sequence in -// question does not have a token for a given step. -constexpr int kPaddingState = -2; - -namespace { - -// Validate that a given constraint tensor is the proper shape (dimension -// 2, with shape [num_states + 1, num_states + 1]. -absl::Status ValidateConstraintTensor(const Tensor &tensor, - const int num_states, - const bool use_start_end_states, - const string &name) { - if (tensor.shape().dims() != 2) { - return InvalidArgument( - tensorflow::strings::StrCat(name, " must be of rank 2")); - } - int expected_size = use_start_end_states ? 
num_states + 1 : num_states; - if (tensor.shape().dim_size(0) != expected_size) { - return InvalidArgument(tensorflow::strings::StrCat( - name, " must have a zeroth dimension of size ", expected_size, - " when num_states is ", num_states, " and use_start_and_end_states is ", - use_start_end_states)); - } - if (tensor.shape().dim_size(1) != expected_size) { - return InvalidArgument(tensorflow::strings::StrCat( - name, " must have a first dimension of size ", expected_size, - " when num_states is ", num_states, " and use_start_and_end_states is ", - use_start_end_states)); - } - return absl::OkStatus(); -} - -} // namespace - -template -class ConstrainedSequence : public OpKernel { - public: - explicit ConstrainedSequence(OpKernelConstruction *context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("use_viterbi", &use_viterbi_)); - OP_REQUIRES_OK(context, context->GetAttr("use_log_space", &use_log_space_)); - OP_REQUIRES_OK(context, context->GetAttr("use_start_and_end_states", - &use_start_end_states_)); - } - - void Compute(OpKernelContext *context) override { - const auto &score_tensor = context->input(0); - OP_REQUIRES(context, - (score_tensor.shape().dims() == 2) || - (score_tensor.shape().dims() == 3), - InvalidArgument("The score tensor must be of rank 2 or 3.")); - const auto &lengths_tensor = context->input(1); - - ScoreAccessor scores(score_tensor, lengths_tensor); - - // The scores tensor should be [batch, step, scores]. - const int batch_size = scores.batch_size(); - const int num_steps = scores.num_steps(); - const int num_scores = scores.num_scores(); - - OP_REQUIRES(context, lengths_tensor.NumElements() == batch_size, - InvalidArgument(tensorflow::strings::StrCat( - "There should be exactly one length for every batch " - "element. 
Found ", - lengths_tensor.NumElements(), - " length elements for a batch size of ", batch_size))); - - VLOG(2) << "batch: " << batch_size; - VLOG(2) << "steps: " << num_steps; - VLOG(2) << "score: " << num_scores; - - // Make sure there's enough data to advance every sequence. - int max_length = 0; - int total_length = 0; - for (int i = 0; i < batch_size; ++i) { - int64 length = scores.GetLength(i); - total_length += length; - if (length > max_length) { - max_length = length; - } - } - - OP_REQUIRES( - context, num_steps >= max_length, - InvalidArgument( - "The scores tensor is too short for the longest sequence length.")); - - // Validate the constraint tensors. - const auto &allowed_transitions_tensor = context->input(2); - bool has_allowed_transitions = - allowed_transitions_tensor.NumElements() != 0; - VLOG(4) << allowed_transitions_tensor.NumElements(); - if (has_allowed_transitions) { - OP_REQUIRES_OK(context, - ValidateConstraintTensor(allowed_transitions_tensor, - num_scores, use_start_end_states_, - "allowed_transitions")); - } - - const auto &transition_weights_tensor = context->input(3); - - VLOG(4) << transition_weights_tensor.NumElements(); - bool has_transition_weights = transition_weights_tensor.NumElements() != 0; - if (has_transition_weights) { - OP_REQUIRES_OK(context, ValidateConstraintTensor( - transition_weights_tensor, num_scores, - use_start_end_states_, "transition_weights")); - - // If we have transition weights in exp-space, all values must be non- - // negative. - if (!use_log_space_) { - for (int i = 0; i < transition_weights_tensor.NumElements(); ++i) { - OP_REQUIRES(context, transition_weights_tensor.flat()(i) >= 0, - InvalidArgument("The transition weights tensor must not " - "contain negative values.")); - } - } - } - - const tensorflow::Tensor empty_float(DT_FLOAT, TensorShape({0, 0})); - const tensorflow::Tensor empty_bool(DT_BOOL, TensorShape({0, 0})); - - const auto &transition_weights = - has_transition_weights ? 
transition_weights_tensor.matrix() - : empty_float.matrix(); - - const auto &allowed_transitions = - has_allowed_transitions ? allowed_transitions_tensor.matrix() - : empty_bool.matrix(); - - Tensor *output; - OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({total_length}), &output)); - int32 *output_data = output->flat().data(); - - Tensor *offsets; - OP_REQUIRES_OK(context, context->allocate_output( - 1, TensorShape({batch_size + 1}), &offsets)); - Tsplits *offset_data = offsets->flat().data(); - offset_data[0] = 0; - - for (int batch = 0; batch < batch_size; ++batch) { - int step_offset = offset_data[batch]; - int64 num_steps = scores.GetLength(batch); - offset_data[batch + 1] = step_offset + num_steps; - if (use_viterbi_) { - DoViterbiAnalysis(transition_weights, allowed_transitions, batch, - scores, &output_data[step_offset]); - } else { - DoGreedyAnalysis(transition_weights, allowed_transitions, batch, scores, - &output_data[step_offset]); - } - } - } - - private: - // Perform Viterbi analysis on a single batch item. - void DoViterbiAnalysis( - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - const int batch, const ScoreAccessor &scores, int32 *output_data) { - ViterbiAnalysis(scores, transition_weights, allowed_transitions, batch, - use_log_space_, use_start_end_states_, output_data); - } - - // Perform a greedy analysis on a single batch item. - void DoGreedyAnalysis( - const tensorflow::TTypes::Matrix &transition_weights, - const tensorflow::TTypes::Matrix &allowed_transitions, - int batch, const ScoreAccessor &scores, int32 *output_data) { - GreedyAnalysis(scores, transition_weights, allowed_transitions, batch, - use_log_space_, use_start_end_states_, output_data); - } - - // True if this op should perform calculations in log-space (using addition). - // If false, will perform calculations in normalized exp-space (using - // multiplication). 
- bool use_log_space_; - - // True if this op should calculate scores using the Viterbi algorithm. If - // false, will use a greedy algorithm. - bool use_viterbi_; - - // True if this op should calculate sequences based on an implicit start - // and end state. - bool use_start_end_states_; - - TF_DISALLOW_COPY_AND_ASSIGN(ConstrainedSequence); -}; - -#define REGISTER_KERNELS(Tin) \ - REGISTER_KERNEL_BUILDER(Name("ConstrainedSequence") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tin") \ - .TypeConstraint("Tsplits"), \ - ConstrainedSequence); \ - REGISTER_KERNEL_BUILDER(Name("ConstrainedSequence") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tin") \ - .TypeConstraint("Tsplits"), \ - ConstrainedSequence) - -REGISTER_KERNELS(int32); -REGISTER_KERNELS(int64); - -#undef REGISTER_KERNELS - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/constrained_sequence_kernel_input_validation_test.cc b/tensorflow_text/core/kernels/constrained_sequence_kernel_input_validation_test.cc deleted file mode 100644 index 343d2d142..000000000 --- a/tensorflow_text/core/kernels/constrained_sequence_kernel_input_validation_test.cc +++ /dev/null @@ -1,496 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { - -using tensorflow::DT_INT32; -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::MatrixEq; -using tensorflow::text_kernels_test_util::VectorEq; - -class ConstrainedSequenceInputValidationTest : public tensorflow::OpsTestBase { - public: - void SetUpOpWithDefaults(bool use_start_end, - tensorflow::DataType input_datatype) { - // Prepare graph. - TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", input_datatype) - .Attr("use_viterbi", true) - .Attr("use_log_space", true) - .Attr("use_start_and_end_states", use_start_end) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } - - void SetUpOpWithStartEnd() { SetUpOpWithDefaults(true, DT_INT32); } - - void SetUpOpWithNoStartEnd() { SetUpOpWithDefaults(false, DT_INT32); } -}; -// TODO(b/122968457): There are a bunch of tests that only validate !ok instead -// of looking for specific error messages; fix that. - -// This test examines evaluations with only a permissions matrix. -TEST_F(ConstrainedSequenceInputValidationTest, WorksWithInt64InputLengths) { - // Prepare graph. - SetUpOpWithDefaults(true, DT_INT64); - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - std::vector input_lengths({1, 1, 1}); - AddInputFromArray(TensorShape({3}), input_lengths); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnOuterWrongSizePermissionMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({4, 5}), - { - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnInnerWrongSizePermissionMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 4}), - { - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnWrongRankPermissionMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({25}), - { - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnOuterWrongSizeWeightMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({4, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnInnerWrongSizeWeightMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 4}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, FailsOnWrongRankWeightMatrix) { - // Prepare graph. - SetUpOpWithStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({25}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -TEST_F(ConstrainedSequenceInputValidationTest, - PassesWithCorrectSizedWeightAndPermissionsMatrix) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({4, 4}), { - true, true, true, true, // - true, true, true, true, // - true, true, true, true, // - true, true, true, true, // - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({4, 4}), {0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 1.0, 1.0}); - auto result = RunOpKernel(); - EXPECT_TRUE(result.ok()); -} - -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnOuterWrongSizePermissionMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({4, 5}), - { - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnInnerWrongSizePermissionMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 4}), - { - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - true, true, true, true, true, // - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnWrongRankPermissionMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({16}), { - true, true, true, true, // - true, true, true, true, // - true, true, true, true, // - true, true, true, true, // - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnOuterWrongSizeWeightMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({4, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnInnerWrongSizeWeightMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 4}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} -TEST_F(ConstrainedSequenceInputValidationTest, - FailsOnWrongRankWeightMatrixWithNoStartEnd) { - // Prepare graph. - SetUpOpWithNoStartEnd(); - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 3.0, 4.0, // - 1.0, 12.0, 3.0, 4.0, // - 1.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({16}), {0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 1.0}); - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/darts_clone_trie_builder.cc b/tensorflow_text/core/kernels/darts_clone_trie_builder.cc deleted file mode 100644 index e204a9783..000000000 --- a/tensorflow_text/core/kernels/darts_clone_trie_builder.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/darts_clone_trie_builder.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/str_cat.h" -#include "include/darts.h" - -namespace tensorflow { -namespace text { -namespace trie_utils { - -absl::StatusOr> BuildDartsCloneTrie( - const std::vector& keys) { - std::vector values(keys.size()); - std::iota(values.begin(), values.end(), 0); - return BuildDartsCloneTrie(keys, values); -} - -absl::StatusOr> BuildDartsCloneTrie( - const std::vector& keys, const std::vector& values) { - if (keys.size() != values.size()) { - return absl::InvalidArgumentError(absl::StrCat( - "The sizes of 'keys' and 'values' must be equal! Keys size: ", - keys.size(), " . Values size: ", values.size())); - } - - { - // Make sure there are no duplicated elements or empty strings in 'keys'. 
- absl::flat_hash_set unique_keys; - for (const auto& key : keys) { - if (key.empty()) { - return absl::InvalidArgumentError( - "The empty string \"\" is found in 'keys', which is not " - "supported."); - } - if (!unique_keys.insert(key).second) { - return absl::InvalidArgumentError( - absl::StrCat("Duplicated key: ", key, ".")); - } - } - } - - // Make sure all values are non-negative. - for (int i = 0; i < keys.size(); ++i) { - if (values[i] < 0) { - return absl::InvalidArgumentError(absl::StrCat( - "All values must be non-negative! Found value: ", values[i], - " for key: ", keys[i], ", at index: ", i)); - } - } - - // Create a vector to hold the indexes. - std::vector vocab_index_sorted(keys.size()); - std::iota(vocab_index_sorted.begin(), vocab_index_sorted.end(), 0); - - // Sort the index by keys. - std::sort( - vocab_index_sorted.begin(), vocab_index_sorted.end(), - [&keys](const int x, const int y) { return keys.at(x) < keys.at(y); }); - - // Create vectors to build the trie. - std::vector trie_keys; - std::vector trie_values; - trie_keys.reserve(keys.size()); - trie_values.reserve(keys.size()); - for (const auto index : vocab_index_sorted) { - trie_keys.push_back(keys.at(index).c_str()); - trie_values.push_back(values[index]); - } - - // Build the trie. - auto trie = std::make_unique(); - trie->build(trie_keys.size(), const_cast(&trie_keys[0]), nullptr, - const_cast(&trie_values[0])); - - // Return the data of darts_clone (an array of 32-bit unsigned int). 
- const uint32_t* trie_array = static_cast(trie->array()); - return std::vector(trie_array, trie_array + trie->size()); -} - -} // namespace trie_utils -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/darts_clone_trie_builder.h b/tensorflow_text/core/kernels/darts_clone_trie_builder.h index 2557134f0..c8d777c1d 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_builder.h +++ b/tensorflow_text/core/kernels/darts_clone_trie_builder.h @@ -12,42 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Builder utils for Darts-clone tries. -// -// Darts-clone is a compact and efficient implementation of Darts (Double-ARray -// Trie System). For more info, see https://github.com/s-yata/darts-clone. -// -// This header file contains utils that build a darts-clone trie. To access such -// a darts-clone trie, use the utils from the companion header file -// darts_clone_trie_wrapper.h. #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_BUILDER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_BUILDER_H_ -#include -#include -#include - -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { -namespace trie_utils { - -// Builds the trie given keys and values, and returns the darts_clone trie -// array data. `keys` and `values` should have the same size; `values[i]` is the -// value for `keys[i]`. `keys` should not contain duplicated elements. In -// addition, the empty string "" should not be in `keys`, because darts_clone -// does not support that. Furthermore, all `values` should be non-negative. -absl::StatusOr> BuildDartsCloneTrie( - const std::vector& keys, const std::vector& values); - -// A variant where the values are indexes in the keys: i.e., the value for -// `keys[i]` is the index `i`. 
-absl::StatusOr> BuildDartsCloneTrie( - const std::vector& keys); - -} // namespace trie_utils -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/darts_clone_trie_builder.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/darts_clone_trie_test.cc b/tensorflow_text/core/kernels/darts_clone_trie_test.cc deleted file mode 100644 index a80c28353..000000000 --- a/tensorflow_text/core/kernels/darts_clone_trie_test.cc +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include "tensorflow_text/core/kernels/darts_clone_trie_builder.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h" - -namespace tensorflow { -namespace text { -namespace trie_utils { - -using ::testing::status::StatusIs; - -TEST(DartsCloneTrieTest, CreateCursorPointToRootAndTryTraverseOneStep) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. - ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - int data; - - cursor = trie.CreateTraversalCursorPointToRoot(); // Create a cursor to point - // to the root. 
- EXPECT_TRUE(trie.TryTraverseOneStep(cursor, 'A')); - EXPECT_FALSE(trie.TryGetData(cursor, data)); - EXPECT_TRUE(trie.TryTraverseOneStep(cursor, 'b')); - EXPECT_FALSE(trie.TryGetData(cursor, data)); - EXPECT_TRUE(trie.TryTraverseOneStep(cursor, 'c')); - EXPECT_TRUE(trie.TryGetData(cursor, data)); - EXPECT_THAT(data, 2); - EXPECT_FALSE(trie.TryTraverseOneStep(cursor, 'c')); -} - -TEST(DartsCloneTrieTest, CreateCursorAndTryTraverseSeveralSteps) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. - ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - int data; - - cursor = trie.CreateTraversalCursor(trie.kRootNodeId); // Create a cursor to - // point to the root. - EXPECT_TRUE(trie.TryTraverseSeveralSteps(cursor, "def")); - EXPECT_TRUE(trie.TryGetData(cursor, data)); - EXPECT_THAT(data, 0); -} - -TEST(DartsCloneTrieTest, TraversePathNotExisted) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. - ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - - trie.SetTraversalCursor( - cursor, - trie.kRootNodeId); // Use SetTraversalCursor() to point to the root. - EXPECT_FALSE(trie.TryTraverseSeveralSteps(cursor, "dez")); -} - -TEST(DartsCloneTrieTest, TraverseOnUtf8Path) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. 
- ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - int data; - - trie.SetTraversalCursor( - cursor, - trie.kRootNodeId); // Use SetTraversalCursor() to point to the root. - EXPECT_TRUE(trie.TryTraverseSeveralSteps(cursor, "\xe1\xb8\x8aZZ")); - EXPECT_TRUE(trie.TryGetData(cursor, data)); - EXPECT_THAT(data, 1); -} - -TEST(DartsCloneTrieTest, TraverseOnPartialUtf8Path) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. - ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - int data; - - trie.SetTraversalCursor( - cursor, - trie.kRootNodeId); // Use SetTraversalCursor() to point to the root. - EXPECT_TRUE(trie.TryTraverseSeveralSteps(cursor, "\xe1\xb8")); - EXPECT_FALSE(trie.TryGetData(cursor, data)); -} - -TEST(DartsCloneTrieTest, TraverseOnUtf8PathNotExisted) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - - // Create the trie instance. - ASSERT_OK_AND_ASSIGN(std::vector trie_array, - BuildDartsCloneTrie(vocab_tokens)); - ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); - - DartsCloneTrieWrapper::TraversalCursor cursor; - - trie.SetTraversalCursor( - cursor, - trie.kRootNodeId); // Use SetTraversalCursor() to point to the root. - EXPECT_FALSE(trie.TryTraverseSeveralSteps(cursor, "\xe1\xb8\x84")); -} - -TEST(DartsCloneTrieBuildError, KeysValuesSizeDifferent) { - // The test vocabulary. - std::vector keys{"def", "\xe1\xb8\x8aZZ", "Abc"}; - std::vector values{1, 2, 3, 4}; - - // Create the trie instance. 
- ASSERT_THAT(BuildDartsCloneTrie(keys, values), - StatusIs(util::error::INVALID_ARGUMENT)); -} - -TEST(DartsCloneTrieBuildError, DuplicatedKeys) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc", "def"}; - - // Create the trie instance. - ASSERT_THAT(BuildDartsCloneTrie(vocab_tokens), - StatusIs(util::error::INVALID_ARGUMENT)); -} - -TEST(DartsCloneTrieBuildError, EmptyStringsInKeys) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc", ""}; - - // Create the trie instance. - ASSERT_THAT(BuildDartsCloneTrie(vocab_tokens), - StatusIs(util::error::INVALID_ARGUMENT)); -} - -TEST(DartsCloneTrieBuildError, NegativeValues) { - // The test vocabulary. - std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; - std::vector vocab_values{0, -1, 1}; - - // Create the trie instance. - ASSERT_THAT(BuildDartsCloneTrie(vocab_tokens, vocab_values), - StatusIs(util::error::INVALID_ARGUMENT)); -} - -} // namespace trie_utils -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h index 43067ec1b..371f28ef3 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h +++ b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h @@ -12,157 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Access utils for Darts-clone tries. -// -// Darts-clone is a compact and efficient implementation of Darts (Double-ARray -// Trie System). For more info, see https://github.com/s-yata/darts-clone. -// -// This header file contains utils that access a darts-clone trie. To build such -// a darts-clone trie, use the utils from the companion header file -// darts_clone_trie_builder.h. 
-// -// Note that although there is a 'traverse()' function in the original source -// (see https://github.com/s-yata/darts-clone/blob/master/include/darts.h), the -// utils in this header file are more efficient and the APIs are more flexible. #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_WRAPPER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_WRAPPER_H_ -#include -#include - -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { -namespace trie_utils { - -// A wrapper class of darts_clone trie for traversing and getting data on the -// trie. It does not own the actual 'trie_array'. -class DartsCloneTrieWrapper { - public: - // Represents the root node id. - static constexpr uint32_t kRootNodeId = 0; - - // A struct serving as the trie traversal cursor. It holds 'node_id' and - // 'unit' (which is 'trie_array_[node_id]'). The reason is to save and reuse - // the 'trie_array_[node_id]'. - struct TraversalCursor { - uint32_t node_id = 0; - uint32_t unit = 0; - }; - - // Constructs an instance by passing in the pointer to the trie array data. - // The caller needs to make sure that 'trie_array' points to a valid structure - // returned by darts_clone trie builder. The caller also needs to maintain the - // availability of 'trie_array' throughout the lifetime of this instance. - static absl::StatusOr Create( - const uint32_t* trie_array) { - if (trie_array == nullptr) { - return absl::InvalidArgumentError("trie_array is nullptr."); - } - return DartsCloneTrieWrapper(trie_array); - } - - // Creates a cursor pointing to the root. - TraversalCursor CreateTraversalCursorPointToRoot() { - return {kRootNodeId, trie_array_[kRootNodeId]}; - } - - // Creates a cursor pointing to the 'node_id'. - TraversalCursor CreateTraversalCursor(uint32_t node_id) { - return {node_id, trie_array_[node_id]}; - } - - // Sets the cursor to point to 'node_id'. 
- void SetTraversalCursor(TraversalCursor& cursor, uint32_t node_id) { - cursor.node_id = node_id; - cursor.unit = trie_array_[node_id]; - } - - // Traverses one step from 'cursor' following 'ch'. If successful (i.e., there - // exists such an edge), moves 'cursor' to the new node and returns true. - // Otherwise, does nothing (i.e., 'cursor' is not changed) and returns false. - bool TryTraverseOneStep(TraversalCursor& cursor, unsigned char ch) const { - const uint32_t next_node_id = cursor.node_id ^ offset(cursor.unit) ^ ch; - const uint32_t next_node_unit = trie_array_[next_node_id]; - if (label(next_node_unit) != ch) { - return false; - } - cursor.node_id = next_node_id; - cursor.unit = next_node_unit; - return true; - } - - // Traverses several steps from 'cursor' following the characters on 'path'. - // If *all* steps are successful, moves 'cursor' to the new node and returns - // true. Otherwise, does nothing (i.e., 'cursor' is not changed) and returns - // false. - bool TryTraverseSeveralSteps(TraversalCursor& cursor, - absl::string_view path) const { - return TryTraverseSeveralSteps(cursor, path.data(), path.size()); - } - - // If the node pointed by 'cursor' has data, read into 'out_data' and returns - // true; otherwise, does nothing and returns false. - bool TryGetData(const TraversalCursor& cursor, int& out_data) const { - if (!has_leaf(cursor.unit)) { - return false; - } - const uint32_t value_unit = - trie_array_[cursor.node_id ^ offset(cursor.unit)]; - out_data = value(value_unit); - return true; - } - - private: - // Use Create() instead of the constructor. - explicit DartsCloneTrieWrapper(const uint32_t* trie_array) - : trie_array_(trie_array) {} - - // The actual implementation of TryTraverseSeveralSteps. 
- bool TryTraverseSeveralSteps(TraversalCursor& cursor, const char* ptr, - int size) const { - uint32_t cur_id = cursor.node_id; - uint32_t cur_unit = cursor.unit; - for (; size > 0; --size, ++ptr) { - const unsigned char ch = static_cast(*ptr); - cur_id ^= offset(cur_unit) ^ ch; - cur_unit = trie_array_[cur_id]; - if (label(cur_unit) != ch) { - return false; - } - } - cursor.node_id = cur_id; - cursor.unit = cur_unit; - return true; - } - - // The helper functions below are based on - // https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h - - // Returns offset to children. - static uint32_t offset(uint32_t unit) { - return (unit >> 10) << ((unit & 0x200) >> 6); - } - - // Returns a label associated with a node. - // A leaf node will have the MSB set and thus return an invalid label. - static uint32_t label(uint32_t unit) { return unit & 0x800000ff; } - - // Returns whether a node has a leaf as a child. - static bool has_leaf(uint32_t unit) { return unit & 0x100; } - - // Returns a value associated with a node. Available when a node is a leaf. - static int value(uint32_t unit) { - return static_cast(unit & 0x7fffffff); - } - - // The pointer to the darts trie array. - const uint32_t* trie_array_; -}; - -} // namespace trie_utils -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/darts_clone_trie_wrapper.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DARTS_CLONE_TRIE_WRAPPER_H_ diff --git a/tensorflow_text/core/kernels/disjoint_set_forest.h b/tensorflow_text/core/kernels/disjoint_set_forest.h index a0b163850..b152784b2 100644 --- a/tensorflow_text/core/kernels/disjoint_set_forest.h +++ b/tensorflow_text/core/kernels/disjoint_set_forest.h @@ -12,171 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ -#include +#include "tensorflow/core/kernels/text/disjoint_set_forest.h" -#include -#include - -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace text { - -// An implementation of the disjoint-set forest data structure. The universe of -// elements is the dense range of indices [0,n). Thread-compatible. -// -// By default, this uses the path compression and union by rank optimizations, -// achieving near-constant runtime on all operations. However, the user may -// disable the union by rank optimization, which allows the user to control how -// roots are selected when a union occurs. When union by rank is disabled, the -// runtime of all operations increases to O(log n) amortized. -// -// Template args: -// Index: An unsigned integral type wide enough to hold n. -// kUseUnionByRank: Whether to use the union by rank optimization. -template -class DisjointSetForest { - public: - static_assert(std::is_integral::value, "Index must be integral"); - static_assert(!std::is_signed::value, "Index must be unsigned"); - using IndexType = Index; - - // Creates an empty forest. - DisjointSetForest() = default; - - // Initializes this to hold the elements [0,|size|), each initially in its own - // singleton set. Replaces existing state, if any. - void Init(Index size); - - // Returns the root of the set containing |element|, which uniquely identifies - // the set. Note that the root of a set may change as the set is merged with - // other sets; do not cache the return value of FindRoot(e) across calls to - // Union() or UnionOfRoots() that could merge the set containing e. 
- Index FindRoot(Index element); - - // For convenience, returns true if |element1| and |element2| are in the same - // set. When performing a large batch of queries it may be more efficient to - // cache the value of FindRoot(), modulo caveats regarding caching above. - bool SameSet(Index element1, Index element2); - - // Merges the sets rooted at |root1| and |root2|, which must be the roots of - // their respective sets. Either |root1| or |root2| will be the root of the - // merged set. If |kUseUnionByRank| is true, then it is unspecified whether - // |root1| or |root2| will be the root; otherwise, |root2| will be the root. - void UnionOfRoots(Index root1, Index root2); - - // As above, but for convenience finds the root of |element1| and |element2|. - void Union(Index element1, Index element2); - - // The number of elements in this. - Index size() const { return size_; } - - private: - // The number of elements in the universe underlying the sets. - Index size_ = 0; - - // The parent of each element, where self-loops are roots. - std::vector parents_; - - // The rank of each element, for the union by rank optimization. Only used if - // |kUseUnionByRank| is true. - std::vector ranks_; -}; - -// Implementation details below. - -template -void DisjointSetForest::Init(Index size) { - size_ = size; - parents_.resize(size_); - if (kUseUnionByRank) ranks_.resize(size_); - - // Create singleton sets. - for (Index i = 0; i < size_; ++i) { - parents_[i] = i; - if (kUseUnionByRank) ranks_[i] = 0; - } -} - -template -Index DisjointSetForest::FindRoot(Index element) { - DCHECK_LT(element, size()); - Index *const __restrict parents = parents_.data(); - - // Walk up to the root of the |element|. Unroll the first two comparisons - // because path compression ensures most FindRoot() calls end there. In - // addition, if a root is found within the first two comparisons, then the - // path compression updates can be skipped. 
- Index current = element; - Index parent = parents[current]; - if (current == parent) return current; // |element| is a root - current = parent; - parent = parents[current]; - if (current == parent) return current; // |element| is the child of a root - do { // otherwise, continue upwards until root - current = parent; - parent = parents[current]; - } while (current != parent); - const Index root = current; - - // Apply path compression on the traversed nodes. - current = element; - parent = parents[current]; // not root, thanks to unrolling above - do { - parents[current] = root; - current = parent; - parent = parents[current]; - } while (parent != root); - - return root; -} - -template -bool DisjointSetForest::SameSet(Index element1, - Index element2) { - return FindRoot(element1) == FindRoot(element2); -} - -template -void DisjointSetForest::UnionOfRoots(Index root1, - Index root2) { - DCHECK_LT(root1, size()); - DCHECK_LT(root2, size()); - DCHECK_EQ(root1, parents_[root1]); - DCHECK_EQ(root2, parents_[root2]); - if (root1 == root2) return; // already merged - Index *const __restrict parents = parents_.data(); - - if (kUseUnionByRank) { - // Attach the lesser-rank root to the higher-rank root. - Index *const __restrict ranks = ranks_.data(); - const Index rank1 = ranks[root1]; - const Index rank2 = ranks[root2]; - if (rank2 < rank1) { - parents[root2] = root1; - } else if (rank1 < rank2) { - parents[root1] = root2; - } else { - // Equal ranks; choose one arbitrarily and promote its rank. - parents[root1] = root2; - ranks[root2] = rank2 + 1; - } - } else { - // Always make |root2| the root of the merged set. 
- parents[root1] = root2; - } -} - -template -void DisjointSetForest::Union(Index element1, - Index element2) { - UnionOfRoots(FindRoot(element1), FindRoot(element2)); -} - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_DISJOINT_SET_FOREST_H_ diff --git a/tensorflow_text/core/kernels/disjoint_set_forest_test.cc b/tensorflow_text/core/kernels/disjoint_set_forest_test.cc deleted file mode 100644 index f63d92f08..000000000 --- a/tensorflow_text/core/kernels/disjoint_set_forest_test.cc +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/disjoint_set_forest.h" - -#include - -#include -#include -#include - -#include -#include - -namespace tensorflow { -namespace text { - -// Testing rig. -// -// Template args: -// Forest: An instantiation of the DisjointSetForest<> template. -template -class DisjointSetForestTest : public ::testing::Test { - protected: - using Index = typename Forest::IndexType; - - // Expects that the |expected_sets| and |forest| match. 
- void ExpectSets(const std::set> &expected_sets, - Forest *forest) { - std::set> expected_pairs; - for (const auto &expected_set : expected_sets) { - for (auto it = expected_set.begin(); it != expected_set.end(); ++it) { - for (auto jt = expected_set.begin(); jt != expected_set.end(); ++jt) { - expected_pairs.emplace(*it, *jt); - } - } - } - - for (Index lhs = 0; lhs < forest->size(); ++lhs) { - for (Index rhs = 0; rhs < forest->size(); ++rhs) { - if (expected_pairs.find({lhs, rhs}) != expected_pairs.end()) { - EXPECT_EQ(forest->FindRoot(lhs), forest->FindRoot(rhs)); - EXPECT_TRUE(forest->SameSet(lhs, rhs)); - } else { - EXPECT_NE(forest->FindRoot(lhs), forest->FindRoot(rhs)); - EXPECT_FALSE(forest->SameSet(lhs, rhs)); - } - } - } - } -}; - -using Forests = ::testing::Types< - DisjointSetForest, DisjointSetForest, - DisjointSetForest, DisjointSetForest, - DisjointSetForest, DisjointSetForest, - DisjointSetForest, DisjointSetForest>; -TYPED_TEST_SUITE(DisjointSetForestTest, Forests); - -TYPED_TEST(DisjointSetForestTest, DefaultEmpty) { - TypeParam forest; - EXPECT_EQ(0, forest.size()); -} - -TYPED_TEST(DisjointSetForestTest, InitEmpty) { - TypeParam forest; - forest.Init(0); - EXPECT_EQ(0, forest.size()); -} - -TYPED_TEST(DisjointSetForestTest, Populated) { - TypeParam forest; - forest.Init(5); - EXPECT_EQ(5, forest.size()); - this->ExpectSets({{0}, {1}, {2}, {3}, {4}}, &forest); - - forest.UnionOfRoots(1, 2); - this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest); - - forest.Union(1, 2); - this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest); - - forest.UnionOfRoots(0, 4); - this->ExpectSets({{0, 4}, {1, 2}, {3}}, &forest); - - forest.Union(3, 4); - this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest); - - forest.Union(0, 3); - this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest); - - forest.Union(2, 0); - this->ExpectSets({{0, 1, 2, 3, 4}}, &forest); - - forest.Union(1, 3); - this->ExpectSets({{0, 1, 2, 3, 4}}, &forest); -} - -// Testing rig for checking that when union by rank 
is disabled, the root of a -// merged set can be controlled. -class DisjointSetForestNoUnionByRankTest : public ::testing::Test { - protected: - using Forest = DisjointSetForest; - - // Expects that the roots of the |forest| match |expected_roots|. - void ExpectRoots(const std::vector &expected_roots, Forest *forest) { - ASSERT_EQ(expected_roots.size(), forest->size()); - for (uint32 i = 0; i < forest->size(); ++i) { - EXPECT_EQ(expected_roots[i], forest->FindRoot(i)); - } - } -}; - -TEST_F(DisjointSetForestNoUnionByRankTest, ManuallySpecifyRoot) { - Forest forest; - forest.Init(5); - ExpectRoots({0, 1, 2, 3, 4}, &forest); - - forest.UnionOfRoots(0, 1); // 1 is the root - ExpectRoots({1, 1, 2, 3, 4}, &forest); - - forest.Union(4, 3); // 3 is the root - ExpectRoots({1, 1, 2, 3, 3}, &forest); - - forest.Union(0, 2); // 2 is the root - ExpectRoots({2, 2, 2, 3, 3}, &forest); - - forest.Union(3, 3); // no effect - ExpectRoots({2, 2, 2, 3, 3}, &forest); - - forest.Union(4, 0); // 2 is the root - ExpectRoots({2, 2, 2, 2, 2}, &forest); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/edit_changes.proto b/tensorflow_text/core/kernels/edit_changes.proto deleted file mode 100644 index 62d622b7a..000000000 --- a/tensorflow_text/core/kernels/edit_changes.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax = "proto2"; - -package tensorflow.text; - -// Protocol buffer for serializing a single icu::Edits object -// represented by a sequence of edit changes pairs: (old_length, new_length) -message EditChanges { - message Change { - optional int32 old_length = 1; - optional int32 new_length = 2; - } - - repeated Change change = 1; -} - diff --git a/tensorflow_text/core/kernels/exp_greedy_constrained_sequence_kernel_test.cc b/tensorflow_text/core/kernels/exp_greedy_constrained_sequence_kernel_test.cc deleted file mode 100644 index 0c91e24be..000000000 --- a/tensorflow_text/core/kernels/exp_greedy_constrained_sequence_kernel_test.cc +++ /dev/null 
@@ -1,854 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { - -using tensorflow::DT_INT32; -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::MatrixEq; -using tensorflow::text_kernels_test_util::VectorEq; - -class ExpGreedyConstrainedSequenceTest : public tensorflow::OpsTestBase { - public: - void SetUpOpWithDefaults() { - // Prepare graph. 
- TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", false) - .Attr("use_log_space", false) - .Attr("use_start_and_end_states", true) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -// TODO(b/122968457): There are a bunch of tests that only validate !ok instead -// of looking for specific error messages; fix that. - -// This test examines evaluations with only a permissions matrix. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty weights matrix not of rank 2. 
-TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a 2D score matrix (implicit batch 1). -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithSingleBatchItem) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({1, 4}), // - { - 10.0, 12.0, 13.0, 4.0, // - }); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({1}), {1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // Validate the output. - std::vector expected_transitions({1}); - std::vector expected_offsets({0, 1}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines int64 input type and int32 output type. -TEST_F(ExpGreedyConstrainedSequenceTest, int64inint32out) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - // Validate the output. - // Validate the output. 
- std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op can take a sequence length of type {{X},{Y},{Z}} -// (with an outer batch dimension). -TEST_F(ExpGreedyConstrainedSequenceTest, TwoDimensionalSequenceLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3, 1}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions that are forbidden by the permission -// matrix (final->null) are not taken. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeightsConstrainedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok; the next - // highest is 1, but 1->OUT is not OK; the next highest is 0, which is OK. - // The second sequence's highest score is 3, OUT->3 is OK and 3->OUT is OK. - // The third sequence's highest score is 0, OUT->0 is OK and 0->OUT is OK. - // Validate the output. - std::vector expected_transitions({0, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with only a weight matrix. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3) - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2) - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1) - // Validate the output. - std::vector expected_transitions({3, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty not rank 2 permissions matrix. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3) - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2) - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1) - // Validate the output. - std::vector expected_transitions({3, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. 
- EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are scored with the probability -// of ending the sequence on the transition (x->final->null). -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsWeightedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 0.1, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row and the last column in the - // score tensor, so the real scores are: - // 1: {1.0, 1.0, 3.5, 0.4} (max is 2) - // 2: {0.1, 4.5, 5.5, 0.5} (max is 2) - // 3: {10.0, 12.0, 1.5, 0.4} (max is 1) - // Validate the output. - std::vector expected_transitions({2, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are not scored with the probability -// of ending the sequence on the transition (x->final->null) if -// use_start_and_end_states is False. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsNotWeightedByEnd) { - // Prepare graph. 
- TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", false) - .Attr("use_log_space", false) - .Attr("use_start_and_end_states", false) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({4, 4}), {0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row and the last column in the - // score tensor, so the real scores are: - // 1: {5.0, 1.0, 3.5, 4.0} (max is 0) - // 2: {.5, 4.5, 5.5, 2.5} (max is 2) - // 3: {50.0, 12.0, 1.5,2.0} (max is 0) - // Validate the output. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with both weight and permission matrices. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 'OUTSIDE' - true, false, true, true, false, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3). OUT->3 is OK. - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2). OUT->2 is OK. - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1). OUT->1 is not OK, so go with 0. - // Note that X->OUT is set to always be OK here. - std::vector expected_transitions({3, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesMultipleTransitionsWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 0.5, 0.5, 1.0, // 2 - 0.5, 0.5, 1.0, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {0.1 6.0 3.5 4.0} (max is 3). OUT->3 is OK. - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2). OUT->2 is OK. - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1). OUT->1 is not OK, so go with 0. - // STEP 2: - // 1: In state '3', so use row 3 in the weight tensor. - // Weights are {5, 5, 10, 5}; 3->2 is OK and 2->OUT is OK; use 2. - // 2: In state '2', so use row 2 in the weight tensor. - // Weights are {5, 7.5, .5, 6.0}; 2->3 is not OK and 2->1 is not OK, so 0. - // 3: In state 0, so use row 0 in the weight tensor. - // Weights are {0.5, 5.5, 0.5, 5}; 0->1 is OK but 1->OUT is not, so 3. - - std::vector expected_transitions({3, 2, 2, 0, 0, 3}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesMultipleTransitionsWithVaryingLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 1, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 0.5, 0.5, 1.0, // 2 - 0.5, 0.5, 1.0, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {0.1 6.0 3.5 4.0} (max is 3). OUT->3 is OK. - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2). OUT->2 and 2->OUT are OK. - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1). OUT->1 is not OK, so go with 0. - // STEP 2: - // 1: In state '3', so use row 3 in the weight tensor. - // Weights are {5, 5, 10, 5}; 3->2 is OK and 2->OUT is OK; use 2. - // 2: End of sequence; no change. - // 3: In state 0, so use row 0 in the weight tensor. - // Weights are {0.5, 5.5, 0.5, 5}; 0->1 is OK but 1->OUT is not, so 3. - - std::vector expected_transitions({3, 2, 2, 0, 3}); - std::vector expected_offsets({0, 2, 3, 5}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a fully negative input set. 
-TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNegativeInputs) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - -10.0, -12.0, -13.0, -4.0, // - -1.0, -12.0, -13.0, -14.0, // - -15.0, -2.0, -3.0, -14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, true, true, true, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - std::vector expected_transitions({3, 0, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an all-zero weight matrix. -TEST_F(ExpGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithZeroedWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), { - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, - }); - - TF_ASSERT_OK(RunOpKernel()); - - // In the case of a tie between weights, the higher state number wins; - // if all weights are zero, the states should all be 3. - - std::vector expected_transitions({3, 3, 3}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -TEST_F(ExpGreedyConstrainedSequenceTest, - ImpossibleSequencesResultInNegativeOnesIfAttrIsSet) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - false, false, false, false, false, // FROM 0 - false, false, false, false, false, // FROM 1 - false, false, false, false, false, // FROM 2 - false, false, false, false, false, // FROM 3 - false, false, false, false, false, // FROM 'OUT' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // Validate the output. - - std::vector expected_transitions({-1, -1, -1, -1, -1, -1}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op will throw an error if there are too few scores to -// finalize all the sequences. 
-TEST_F(ExpGreedyConstrainedSequenceTest, ErrorsIfGivenInsufficientScores) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 2, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/exp_viterbi_constrained_sequence_kernel_test.cc b/tensorflow_text/core/kernels/exp_viterbi_constrained_sequence_kernel_test.cc deleted file mode 100644 index 49cfa02be..000000000 --- a/tensorflow_text/core/kernels/exp_viterbi_constrained_sequence_kernel_test.cc +++ /dev/null @@ -1,910 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { - -using tensorflow::DT_INT32; -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::MatrixEq; -using tensorflow::text_kernels_test_util::VectorEq; - -class ExpViterbiConstrainedSequenceTest : public tensorflow::OpsTestBase { - public: - void SetUpOpWithDefaults() { - // Prepare graph. - TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", true) - .Attr("use_log_space", false) - .Attr("use_start_and_end_states", true) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -// TODO(b/122968457): There are a bunch of tests that only validate !ok instead -// of looking for specific error messages; fix that. - -// This test examines evaluations with only a permissions matrix. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty weights matrix not of rank 2. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. 
- // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a 2D score matrix (implicit batch 1). -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithSingleBatchItem) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({1, 4}), // - { - 10.0, 12.0, 13.0, 4.0, // - }); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({1}), {1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // Validate the output. - std::vector expected_transitions({1}); - std::vector expected_offsets({0, 1}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines int64 input type and int32 output type. -TEST_F(ExpViterbiConstrainedSequenceTest, int64inint32out) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. 
- AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - // Validate the output. - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op can take a sequence length of type {{X},{Y},{Z}} -// (with an outer batch dimension). -TEST_F(ExpViterbiConstrainedSequenceTest, TwoDimensionalSequenceLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3, 1}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions that are forbidden by the permission -// matrix (final->null) are not taken. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeightsConstrainedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok; the next - // highest is 1, but 1->OUT is not OK; the next highest is 0, which is OK. - // The second sequence's highest score is 3, OUT->3 is OK and 3->OUT is OK. - // The third sequence's highest score is 0, OUT->0 is OK and 0->OUT is OK. - // Validate the output. 
- std::vector expected_transitions({0, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with only a weight matrix. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3) - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2) - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1) - // Validate the output. - std::vector expected_transitions({3, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty not rank 2 permissions matrix. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. 
- AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3) - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2) - // 3: {10.0, 12.0, 1.5, 4.0} (max is 1) - // Validate the output. - std::vector expected_transitions({3, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are scored with the probability -// of ending the sequence on the transition (x->final->null). -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsWeightedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 0.1, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row and the last column in the - // score tensor, so the real scores are: - // 1: {1.0, 1.0, 3.5, 0.4} (max is 2) - // 2: {0.1, 4.5, 5.5, 0.5} (max is 2) - // 3: {10.0, 12.0, 1.5, 0.4} (max is 1) - // Validate the output. - std::vector expected_transitions({2, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are not scored with the probability -// of ending the sequence on the transition (x->final->null) if -// use_start_and_end_states is False. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsNotWeightedByEnd) { - // Prepare graph. - TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", true) - .Attr("use_log_space", false) - .Attr("use_start_and_end_states", false) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({4, 4}), {0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5, // - 0.5, 0.5, 0.5, 0.5}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row and the last column in the - // score tensor, so the real scores are: - // 1: {5.0, 1.0, 3.5, 4.0} (max is 0) - // 2: {.5, 4.5, 5.5, 2.5} (max is 2) - // 3: {50.0, 12.0, 1.5,2.0} (max is 0) - // Validate the output. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with both weight and permission matrices. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 'OUTSIDE' - true, false, true, true, false, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // 1: {1.0, 1.0, 3.5, 4.0} (max is 3). OUT->3 is OK. - // 2: {0.1, 4.5, 5.5, 5.0} (max is 2). OUT->2 is OK. 
- // 3: {10.0, 12.0, 1.5, 4.0} (max is 1). OUT->1 is not OK, so go with 0. - // Note that X->OUT is set to always be OK here. - std::vector expected_transitions({3, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesMultipleTransitionsWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 0.5, 0.5, 1.0, // 2 - 0.5, 0.5, 1.0, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // B0: { 1.0, [NOTOK], 3.5, 4.0} - // B1: { 0.1, [NOTOK], 5.5, 5.0} - // B2: {10.0, [NOTOK], 1.5, 4.0} - // - // STEP 2: - // (Forbidden transitions are marked with '*') - // - // BATCH 0: - // Raw scores are: {10.0, 10.0, 10.0, 10.0} - // from 0: New scores are {5.0, 5.0, 5.0, 5.0}, totals: {5, 0, 17.5, 20} - // from 1: New scores are {5.0, 5.0, 0*, 5.0}, totals: {5, 0, 0, 20} - // from 2: New scores are {5.0, 5.0, 5.0, 10.0}, totals: {5, 0, 17.5, 40} - // from 3: New scores are {5.0, 5.0, 0*, 5.0}, totals: {5, 0, 0, 20} - // Top scores are 20, 20, 40, 20 from [3, 3, 3, 3]. - // 1->OUT is not valid. - // Final scores are [20, 0, 40, 20] for a - // final state of [2] with a sequence of [3->2]. - // - // BATCH 1: - // Raw scores are {10, 15, 1, 12} - // from 0: Weighted score is {5, 5, 5, 5}, totals: {0.5, 0, 27.5, 25} - // from 1: Weighted score is {7.5, 7.5, 0*, 7.5}, t: {0.75, 0, 0, 37.5} - // from 2: Weighted score is {0.5, 0.5, 0.5, 1.0}, t: {0.05, 0, 2.75, 5} - // from 3: Weighted score is {6, 6, 0*, 6}, totals: {0.6, 0, 0, 30} - // Top scores are {27.5, 37.5, 5, 30} from [2, 3, 3, 3] - // 1->OUT is not valid, so final scores are [27.5, 0, 5, 30] for a final - // state of [3] and a sequence of [3, 3] - // - // BATCH 2: - // Raw scores are {1.0, 11.0, 1.0, 10.0} - // 2/0: Weighted score is {.5, .5, .5, .5}. t: {5, 0, 0.75, 2} - // 2/1: Weighted score is {5.5, 5.5, 0*, 5.5}. t: {55, 0, 0, 22} - // 2/2: Weighted score is {.5, .5, .5, 1.0}. t: {5, 0, 0.75, 4} - // 2/3: Weighted score is {5, 5, 0*, 5}. 
t: {50, 0, 0, 20} - // Top scores are {5, 55, 5, 50} from [0, 0, 0, 0] - // 1->OUT is not valid, so final scores are [5, 0, 5, 50] for a final - // state of 3 and a sequence of [0, 3]. - - std::vector expected_transitions({3, 2, 3, 3, 0, 3}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesMultipleTransitionsWithVaryingLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 1, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 0.5, 0.5, 1.0, // 2 - 0.5, 0.5, 1.0, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be multiplied by the last row in the weight tensor, so - // the 'real' scores are: - // B0: { 1.0, [NOTOK], 3.5, 4.0} - // B1: { 0.1, [NOTOK], 5.5, 5.0} - // B2: {10.0, [NOTOK], 1.5, 4.0} - // - // STEP 2: - // (Forbidden transitions are marked with '*') - // - // BATCH 0: - // Raw scores are: {10.0, 10.0, 10.0, 10.0} - // from 0: New scores are {5.0, 5.0, 5.0, 5.0}, totals: {5, 0, 17.5, 20} - // from 1: New scores are {5.0, 5.0, 0*, 5.0}, totals: {5, 0, 0, 20} - // from 2: New scores are {5.0, 5.0, 5.0, 10.0}, totals: {5, 0, 17.5, 40} - // from 3: New scores are {5.0, 5.0, 0*, 5.0}, totals: {5, 0, 0, 20} - // Top scores are 20, 20, 40, 20 from [3, 3, 3, 3]. - // 1->OUT is not valid. - // Final scores are [20, 0, 40, 20] for a - // final state of [2] with a sequence of [3->2]. - // - // BATCH 1: - // End of sequence; no further action. - // - // BATCH 2: - // Raw scores are {1.0, 11.0, 1.0, 10.0} - // 2/0: Weighted score is {.5, .5, .5, .5}. t: {5, 0, 0.75, 2} - // 2/1: Weighted score is {5.5, 5.5, 0*, 5.5}. t: {55, 0, 0, 22} - // 2/2: Weighted score is {.5, .5, .5, 1.0}. t: {5, 0, 0.75, 4} - // 2/3: Weighted score is {5, 5, 0*, 5}. t: {50, 0, 0, 20} - // Top scores are {5, 55, 5, 50} from [0, 0, 0, 0] - // 1->OUT is not valid, so final scores are [5, 0, 5, 50] for a final - // state of 3 and a sequence of [0, 3]. - - std::vector expected_transitions({3, 2, 2, 0, 3}); - std::vector expected_offsets({0, 2, 3, 5}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an all-zero weight matrix. 
-TEST_F(ExpViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithZeroedWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), { - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, - }); - - TF_ASSERT_OK(RunOpKernel()); - - // In the case of a tie between weights, the higher state number wins; - // if all weights are zero, the states should all be 3. - - std::vector expected_transitions({3, 3, 3}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -TEST_F(ExpViterbiConstrainedSequenceTest, - ImpossibleSequencesResultInNegativeOnesIfAttrIsSet) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - false, false, false, false, false, // FROM 0 - false, false, false, false, false, // FROM 1 - false, false, false, false, false, // FROM 2 - false, false, false, false, false, // FROM 3 - false, false, false, false, false, // FROM 'OUT' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // Validate the output. - - std::vector expected_transitions({-1, -1, -1, -1, -1, -1}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op will throw an error if there are too few scores to -// finalize all the sequences. -TEST_F(ExpViterbiConstrainedSequenceTest, ErrorsIfGivenInsufficientScores) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 2, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -// This test ensures that the op correctly outputs a ragged tensor with type -// int32 -TEST_F(ExpViterbiConstrainedSequenceTest, OutputsInt32RaggedTensor) { - // Prepare graph. 
- SetUpOpWithDefaults(); - - AddInputFromArray( - TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Tr. to 3 - 10.0, 10.0, 10.0, 10.0, // Tr. 3 to 2 on wt. - 1.0, 9.0, 11.0, 5.0, // Tr. to 2 - 10.0, 15.0, 1.0, 12.0, // Irrelevant (past end of sequence) - 100.0, 24.0, 3.0, 4.0, // Tr. to 0 - 1.0, 10.0, 1.0, 10.0, // Tr. 0 to 3 (1 cannot tr. to NULL) - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 1, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 0.5, 0.5, 1.0, // 2 - 0.5, 0.5, 1.0, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - std::vector expected_transitions({3, 2, 2, 0, 3}); - std::vector expected_offsets({0, 2, 3, 5}); - - // Validate the output. 
- EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer.h b/tensorflow_text/core/kernels/fast_bert_normalizer.h index efd5102bf..52721e97a 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer.h @@ -15,325 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_ -#include -#include - -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_model_generated.h" - -namespace tensorflow { -namespace text { -namespace text_norm { - -// Bit configurations to encode the mapped normalized value. Currently, -// - The 1st bit (from the left) is reserved by Darts-clone Trie. -// - The 2nd bit stores whether the normalized string is different from the -// codepoint itself. It is also used to differentiate from value 0, which is -// the value returned by `LookupData()` when the codepoint is not stored on -// the trie. -// - The next 24 bits (3 to 26) encode the offset of the normalized string in -// a shared pool. -// - The last 6 bits (27 to 32) encode the length of utf8 bytes of the -// normalized string. - -// The 2rd bit stores whether the normalized string is different from itself. -static constexpr unsigned int kIsNormalizedStringDifferentMask = 0x40000000; - -// Number of lowest bits to represent the length of utf8 bytes of mapped -// values. 6-bit is enough to encode the length of the normalized strings. 
-static constexpr unsigned int kBitsToEncodeUtf8LengthOfNormalizedString = 6; - -// The mask for getting the length of the normalized string. It equals to 0x3F -// when `kBitsToEncodeUtf8LengthOfNormalizedString = 6`. -static constexpr unsigned int kNormalizedStringLengthMask = - (1 << kBitsToEncodeUtf8LengthOfNormalizedString) - 1; - -// Maximum length of utf8 bytes of normalized strings. It equals to 63 -// when `kBitsToEncodeUtf8LengthOfNormalizedString = 6`. -static constexpr unsigned int kMaximumUtf8LengthOfNormalizedString = - (1 << kBitsToEncodeUtf8LengthOfNormalizedString) - 1; - -// The mask for getting the offset of the normalized string in the pool. It -// equals to 0x3FFFFFC0 when `kBitsToEncodeUtf8LengthOfNormalizedString = 6`. -static constexpr unsigned int kNormalizedStringOffsetMask = - (kIsNormalizedStringDifferentMask - 1) ^ kNormalizedStringLengthMask; - -// Each normalized string is represented as a continuous utf-8 substring in a -// pool. `kMaximumOffsetOfNormalizedString` denotes the maximum offset supported -// here. -static constexpr unsigned int kMaximumOffsetOfNormalizedString = - (1 << (32 - 2 - kBitsToEncodeUtf8LengthOfNormalizedString)) - 1; - -} // namespace text_norm - -// A fast text normalizer for BERT based on codepoint-wise mappings. -class FastBertNormalizer { - public: - // Creates an instance. - // - // Args: - // * trie_data: the pointer to the trie data, which is not owned by this - // instance and should be kept alive through the lifetime of the instance. - // * data_for_codepoint_zero: the mapped data for the codepoint zero. - // * normalized_string_pool: the pointer to the normalized string pool data, - // which is not owned by this instance and should be kept alive through the - // lifetime of the instance. 
- static absl::StatusOr Create( - const uint32_t* trie_data, int data_for_codepoint_zero, - const char* normalized_string_pool) { - FastBertNormalizer result; - SH_ASSIGN_OR_RETURN(auto trie, - trie_utils::DartsCloneTrieWrapper::Create(trie_data)); - result.trie_ = - std::make_unique(std::move(trie)); - result.data_for_codepoint_zero_ = data_for_codepoint_zero; - result.normalized_string_pool_ = - reinterpret_cast(normalized_string_pool); - return result; - } - - // Creates an instance. - // - // Args: - // * model_flatbuffer: the pointer to the FastBertNormalizerModel - // flatbuffer, which is not owned by this instance and should be kept alive - // through the lifetime of the instance. - static absl::StatusOr Create( - const void* model_flatbuffer) { - // `GetFastBertNormalizerModel()` is autogenerated by flatbuffer. - auto model = GetFastBertNormalizerModel(model_flatbuffer); - return Create( - model->trie_array()->data(), model->data_for_codepoint_zero(), - reinterpret_cast(model->normalized_string_pool()->data())); - } - - // Normalizes the input based on config `lower_case_nfd_strip_accents`. - // - // It keeps track that, for each byte in the normalized string, which position - // in the original input it should best map to (see below notes). - // - // Here are a few examples (assuming `lower_case_nfd_strip_accents=true`): - // * Input: "ABC" - // Output: "abc" - // Mapping: 0,1,2,3 - // Explanation: "A" -> "a", "B" -> "b", "C" -> "c". The start position of - // "a" maps to position 0 in the input; its exclusive end position equals to - // the start position of "b", which maps to position 1 in the input. The - // start position of "c" maps to position 2 in the input. The exclusive end - // position of "c" (which is also the end of the normalized string) maps to - // position 3 in the input (i.e., the end of input). - // * Input: "B\x41\xCC\x80C" - // Output: "bac" - // Mapping: 0,1,4,5 - // Explanation: "\x41\xCC\x80" -> "a". 
So the start position of "a" maps to - // position 1 in the input; the exclusive end position of "a" (which is also - // the start position of "c") is position 4 in the input. The exclusive end - // position of "c" (which is also the end of the normalized string) maps to - // position 5 in the input (i.e., the end of input). - // * Input: "a\xCE\x89" - // Output: "a\xCE\xB7" - // Mapping: 0,1,1,3 - // Explanation: "\xCE\xB9" (2 bytes) -> "\xCE\xB7" (2 bytes). Because - // "\xCE\xB7" represents the normalized string of the codepoint U+0389 (i.e. - // "\xCE\x90"), their start positions both map to position 1 in the input - // (which is the start position of that codepoint). - // * Input: "a\xC2\xBC" - // Output = "a1\xE2\x81\x84""4" - // Mapping: 0,1,1,1,1,1,3 - // Explanation: "\xC2\xBC" (2 bytes) -> "1\xE2\x81\x84""4" (5 bytes). The - // start points of those 5 bytes all point to position 1 in the input, which - // is the start position of that codepoint. - // - // Note that if the input character is not changed after normalization, the - // bytes are mapped to their original byte locations. For example: - // * Input: "a\xCC\x80" - // Output: "a\xCC\x80" - // Mapping: 0,1,2,3 - // However, if a multibyte character is changed after normalization, all bytes - // of the result character map to the first byte of the character in the - // input. - // * Input: "a\xCE\x89" - // Output: "a\xCE\xB7" - // Mapping: 0,1,1,3 - // The reasons are two-folds: - // 1. When a multibyte character is changed after normalizatoon, it is not - // always feasible to map every internal byte in the output back to their - // corresponding byte in the input. For example, consider the cases where - // 2-bytes are normalized to 3-bytes or vice versa. - // 2. The mapping of the internal bytes in the normalized text is usually not - // used, because users work with UTF-8 output in unit of codepoints, and only - // the mapping of the first byte is important. 
- // - // - // This function does not check whether the input is valid utf-8. This - // behavior is consistent with the existing TF.Text::BertTokenizer. - // - // Args: - // * input_text: The input text. - // * is_output_identical_as_input: True if the normalized string is the - // same as the input. In this case, `output_normalized_text` is empty and - // `output_normalized_offset_mapping` is not changed. - // * output_normalized_text: The normalized text. - // * output_normalized_offset_mapping: In addition to the existing content, - // the extended new content has size 1 plus the size of `normalized_text`. - // Each value is the mapped offset of each byte of `normalized_text` in the - // original `input_text`. The final value maps the end of `normalized_text` - // to the end of `input_text`. - template - void NormalizeText(absl::string_view input_text, - bool* is_output_identical_as_input, - std::string* output_normalized_text, - std::vector* output_normalized_offset_mapping) const { - *output_normalized_text = ""; - // `output_normalized_offset_mapping` is not cleared so the existing content - // is kept. - int last_pos_to_copy_over = 0; // Mark where the copy stopped last time. - auto copy_unchanged_input_to_output = - [input_text, output_normalized_text, output_normalized_offset_mapping, - &last_pos_to_copy_over](int exclusive_copy_end) { - // Copy from `last_pos_to_copy_over` to `exclusive_copy_end` and - // update `last_pos_to_copy_over` accordingly. - if (last_pos_to_copy_over < exclusive_copy_end) { - absl::StrAppend( - output_normalized_text, - input_text.substr(last_pos_to_copy_over, - exclusive_copy_end - last_pos_to_copy_over)); - if constexpr (kGetOffsets) { - for (int i = last_pos_to_copy_over; i < exclusive_copy_end; ++i) { - output_normalized_offset_mapping->push_back(i); - } - } - last_pos_to_copy_over = exclusive_copy_end; - } - }; - int cur_pos = 0; // Current position in `input_text` to process. 
- while (cur_pos < input_text.size()) { - int next_pos = cur_pos; - U8_FWD_1(input_text.data(), next_pos, input_text.size()); - const int cp_byte_length = next_pos - cur_pos; - if (cp_byte_length == 0) { - // The codepoint here has length 0, which is probably invalid UTF-8. - // Copy the remaining unchanged text if any. - copy_unchanged_input_to_output(cur_pos); - // Output a whitespace here to replace the invalid UTF-8 byte. - absl::StrAppend(output_normalized_text, " "); - if constexpr (kGetOffsets) { - output_normalized_offset_mapping->push_back(cur_pos); - } - // Move by one byte. - ++cur_pos; - // Mark the next position to copy over. - last_pos_to_copy_over = cur_pos; - continue; - } - const int encoded_data = - LookupData(input_text.substr(cur_pos, cp_byte_length)); - if (!IsNormalizedStringDifferent(encoded_data)) { - // The codepoint is the same as the normalized. We skip here and copy - // over in an aggregation way for efficiency reasons. - cur_pos += cp_byte_length; // Now move by one codepoint. - continue; - } - absl::string_view normalized_codepoint = - GetNormalizedString(encoded_data); - // Copy the previous unchanged text if any. - copy_unchanged_input_to_output(cur_pos); - - // Output the normalized codepoint text. - absl::StrAppend(output_normalized_text, normalized_codepoint); - if constexpr (kGetOffsets) { - // Every byte of the normalized string should be map to the same start - // position of the current codepoint in the original `input_text`. - for (int i = 0; i < normalized_codepoint.size(); ++i) { - output_normalized_offset_mapping->push_back(cur_pos); - } - } - // Move by one codepoint. - cur_pos += cp_byte_length; - // Mark the next position to copy over. - last_pos_to_copy_over = cur_pos; - } - if (last_pos_to_copy_over == 0) { - // This means that the normalized string would be the same as the input. 
- *is_output_identical_as_input = true; - return; - } - *is_output_identical_as_input = false; - // Copy the remaining unchanged text if any. - copy_unchanged_input_to_output(input_text.size()); - // Push one more mapping from end_of_normalized to end_of_original. - if constexpr (kGetOffsets) { - output_normalized_offset_mapping->push_back(input_text.size()); - } - } - - private: - // Use the public Create() method. - FastBertNormalizer() {} - - // Returns true if the normalized string is different from the codepoint (from - // the encoded `data`). If `data`==0, it means the normalized string is the - // same; in that case, this function returns false correctly. - static bool IsNormalizedStringDifferent(int data) { - return static_cast(data & - text_norm::kIsNormalizedStringDifferentMask); - } - - // Calls this only when IsNormalizedStringDifferent(data) returns true. - absl::string_view GetNormalizedString(int data) const { - const int len = data & text_norm::kNormalizedStringLengthMask; - if (!len) { - return ""; - } - const int offset = (data & text_norm::kNormalizedStringOffsetMask) >> - text_norm::kBitsToEncodeUtf8LengthOfNormalizedString; - return absl::string_view(normalized_string_pool_ + offset, len); - } - - // Looks up the character in format of utf8 string format and returns the - // associated data. If not found, returns 0. Note that 0 also means the - // normalized string is the same as the codepoint itself (refer to - // `kIsNormalizedStringDifferentMask`). - int LookupData(absl::string_view utf8_view) const { - return LookupData(utf8_view.data(), utf8_view.size()); - } - - // The actual implementation of LookupData(). 'utf8_view_ptr' and 'size' - // should point to the utf8 view of a codepoint. Performance-critical. - // Implicitly inline. - int LookupData(const char* utf8_view_ptr, int size) const { - // Darts_clone trie cannot encode the empty input string, so we store and - // return this value separately. 
- if (size == 0 || *utf8_view_ptr == '\0') return data_for_codepoint_zero_; - auto cursor = trie_->CreateTraversalCursorPointToRoot(); - if (!trie_->TryTraverseSeveralSteps( - cursor, absl::string_view(utf8_view_ptr, size))) { - return 0; - } - int data; - if (!trie_->TryGetData(cursor, data)) { - return 0; - } - return data; - } - - // Provides traversal/data-accessing methods on the trie. It has a pointer - // that points to 'trie_array_'. - std::unique_ptr trie_; - - // The encoded data for the special codepoint '\0'. Darts_clone trie cannot - // encode the empty string, so we store this value separately. - int data_for_codepoint_zero_; - - // The string pool of normalized strings. Each normalized string is a - // substring denoted by (offset and length). - const char* normalized_string_pool_; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_bert_normalizer.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_H_ diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h b/tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h index ae795b53b..d34da0e96 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h @@ -15,246 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_KERNEL_TEMPLATE_H_ -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer.h" +#include "tensorflow/core/kernels/text/fast_bert_normalizer_kernel_template.h" -namespace tensorflow { -namespace text { - -// See `kDoc` data member for the documentation on this op kernel. 
-// -// This template class can be instantiated into a kernel for either TF or -// TFLite. See go/tfshim for more info on how this works. -template -class FastBertNormalizeOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { kInputValues = 0, kFastBertNormalizerModel }; - enum Outputs { - kOutputValues = 0, - kOutputOffsets, - kOutputRowSplitsOfOffsets, - }; - - using Shape = tflite::shim::Shape; - using - typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - static const char kGetOffsetsAttr[]; - - // The real work of the invoke operation. - template - absl::Status InvokeRealWork(InvokeContext* context); - - bool get_offsets_; - - public: - FastBertNormalizeOp() = default; - static constexpr char kOpName[] = "FastBertNormalize"; - static constexpr char kDoc[] = R"doc( - Normalizes texts. - - It returns the normalized texts and the relative offsets from the normalized - text to the original text. - - Args: - * input_values: 1D Tensor of strings to normalize. - * fast_bert_normalizer_model: Buffer tensor for the FastBertNormalizerModel - flatbuffer. - - Returns: - * output_values: 1D tensor containing the normalized text for all input - strings. The shape is the same as the input strings. - * output_offsets: 1D tensor containing the offset mapping from the - normalized text to the original text. A 2D RaggedTensor can be constructed - from this and output_row_splits. For example, if the input is - `input_values[i1...iN]` with `N` strings, the constructed 2D RaggedTensor - `offsets[i1...iN, k]` is the byte offset in `input_values[i1...iN]` for - the `kth` byte in `output_values[i1...iN]` after normalization. 
Note that - `offsets[i1...iN, ...]` also covers the position following the last byte - in the normalized `output_values[i1...iN]`, so that we know the byte - offset position in `input_values[i1...iN]` that corresponds to the end of - `output_values[i1...iN]`. - - - * output_row_splits: 1D int tensor with the row splits that allow us to - build RaggedTensors from output_offsets. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs(); - - // Input tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Output tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context); - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -const char FastBertNormalizeOp::kGetOffsetsAttr[] = - "get_offsets"; - -template -std::vector FastBertNormalizeOp::Attrs() { - return { - absl::StrCat(kGetOffsetsAttr, ": bool = false"), - }; -} - -template -std::vector FastBertNormalizeOp::Inputs() { - return {"input_values: string", "fast_bert_normalizer_model: uint8"}; -} - -template -std::vector FastBertNormalizeOp::Outputs() { - return {"output_values: string", "output_offsets: int64", - "output_row_splits: int64"}; -} - -template -absl::Status FastBertNormalizeOp::Init(InitContext* context) { - SH_RETURN_IF_ERROR( - context->GetAttr(kGetOffsetsAttr, &get_offsets_)); - return absl::OkStatus(); -} - -template -absl::Status FastBertNormalizeOp::Invoke(InvokeContext* context) { - if (get_offsets_) { - return InvokeRealWork(context); - } else { - return InvokeRealWork(context); - } -} - -template 
-template -absl::Status FastBertNormalizeOp::InvokeRealWork(InvokeContext* context) { - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto& values_vec = input_values->template As(); - - SH_ASSIGN_OR_RETURN(const auto fast_bert_normalizer_model, - context->GetInput(kFastBertNormalizerModel)); - // OK to create on every call because FastBertNormalizer is a lightweight, - // memory-mapped wrapper on `fast_bert_normalizer_model` tensor, and thus - // Create() is very cheap. - auto text_normalizer = FastBertNormalizer::Create( - fast_bert_normalizer_model->template Data().data()); - SH_RETURN_IF_ERROR(text_normalizer.status()); - - SH_ASSIGN_OR_RETURN( - auto output_values, - context->GetOutput(kOutputValues, Shape(input_values->Shape()))); - auto output_values_vec = output_values->template As(); - std::vector offsets; - std::vector row_splits; - - if constexpr (kGetOffsets) { - row_splits.push_back(0); - } - - // Iterate through all the values and normalize them. - for (int i = 0; i < values_vec.Dim(0); ++i) { - // Normalize and record the offset locations. - std::string normalized_string; - bool is_normalized_string_identical; - const int original_size = offsets.size(); - - text_normalizer->template NormalizeText( - values_vec(i), &is_normalized_string_identical, &normalized_string, - &offsets); - if (is_normalized_string_identical) { - // When the input string is not changed after normalization, - // `normalized_string` is empty and `offsets` is not changed by - // the above function. So here we construct the corresponding result and - // append to the final output. - output_values_vec(i) = values_vec(i); // The normalized text. - if constexpr (kGetOffsets) { - // The offset mapping will be the identy mapping. - for (int j = 0; j < values_vec(i).size(); ++j) { - offsets.push_back(j); - } - // The mapping from the end of the output to the end of the input. 
- offsets.push_back(values_vec(i).size()); - } - } else { - output_values_vec(i) = normalized_string; - } - - if constexpr (kGetOffsets) { - // Record the row splits. - const int delta_size = offsets.size() - original_size; - row_splits.push_back(delta_size + row_splits.back()); - } - } - - if constexpr (kGetOffsets) { - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - offsets, kOutputOffsets, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - row_splits, kOutputRowSplitsOfOffsets, context)); - } else { - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - offsets, kOutputOffsets, context)); - row_splits.resize(1+values_vec.Dim(0)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - row_splits, kOutputRowSplitsOfOffsets, context)); - } - return absl::OkStatus(); -} - -template -absl::Status FastBertNormalizeOp::ShapeInference(ShapeInferenceContext* c) { - using tflite::shim::Shape; - SH_ASSIGN_OR_RETURN(const Shape input_values_shape, - c->GetInputShape(kInputValues)); - SH_ASSIGN_OR_RETURN(const auto fast_bert_normalizer_model_shape, - c->GetInputShape(kFastBertNormalizerModel)); - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Input values shape must be rank 1: ", input_values_shape.ToString())); - } - if (!fast_bert_normalizer_model_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Fast BERT normalizer model shape must be rank 1: ", - fast_bert_normalizer_model_shape.ToString())); - } - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputValues, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputOffsets, rank_1_shape)); - // row splits size - const int num_splits = Shape::AddDims(1, input_values_shape.Dim(0)); - SH_RETURN_IF_ERROR( - c->SetOutputShape(kOutputRowSplitsOfOffsets, Shape({num_splits}))); - - return absl::OkStatus(); -} - -} // namespace text -} // namespace 
tensorflow #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_model.fbs b/tensorflow_text/core/kernels/fast_bert_normalizer_model.fbs deleted file mode 100644 index 75b57d49a..000000000 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_model.fbs +++ /dev/null @@ -1,20 +0,0 @@ -namespace tensorflow.text; - -table FastBertNormalizerModel { - // If true, a preprocessing step is added to lowercase the text, apply NFD - // normalization, and strip accents characters. - lower_case_nfd_strip_accents: bool; - - // The trie data, in the format of darts_clone trie, for input normalization. - trie_array: [uint32]; - - // The encoded data for the special codepoint '\0'. Darts_clone trie cannot - // encode the empty string, so we store this value separately. - data_for_codepoint_zero: int32; - - // The string pool of normalized strings. Each normalized string is a - // substring denoted by (offset and length). - normalized_string_pool: [ubyte]; -} - -root_type FastBertNormalizerModel; diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc b/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc deleted file mode 100644 index 808d2a09c..000000000 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h" - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/errorcode.h" -#include "icu4c/source/common/unicode/normalizer2.h" -#include "icu4c/source/common/unicode/utf.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "re2/re2.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_builder.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_model_generated.h" - -namespace tensorflow { -namespace text { -namespace { -// Adapted from CaseFoldUTF8Op::Compute() in -// https://github.com/tensorflow/text/blob/master/tensorflow_text/core/kernels/normalize_kernels.cc. -absl::StatusOr case_fold_utf8(absl::string_view input) { - std::string output_text; - icu::ErrorCode icu_error; - const icu::Normalizer2* nfkc_cf = - icu::Normalizer2::getNFKCCasefoldInstance(icu_error); - if (!icu_error.isSuccess()) { - return absl::InternalError( - "Could not retrieve ICU NFKC_CaseFold normalizer"); - } - icu::StringByteSink byte_sink(&output_text); - nfkc_cf->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()), - byte_sink, nullptr, icu_error); - if (!icu_error.isSuccess()) { - return absl::InternalError( - absl::StrCat("Could not normalize input string: ", input)); - } - return output_text; -} - -// Adapted from NormalizeUTF8Op::Compute() in -// https://github.com/tensorflow/text/blob/master/tensorflow_text/core/kernels/normalize_kernels.cc. 
-absl::StatusOr normalize_utf8_nfd(absl::string_view input) { - icu::ErrorCode icu_error; - const icu::Normalizer2* normalizer = - icu::Normalizer2::getNFDInstance(icu_error); - if (!icu_error.isSuccess()) { - return absl::InternalError(absl::StrCat( - icu_error.errorName(), ": Could not retrieve ICU NFD normalizer")); - } - std::string output_text; - icu::StringByteSink byte_sink(&output_text); - normalizer->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()), - byte_sink, nullptr, icu_error); - if (!icu_error.isSuccess()) { - return absl::InternalError(absl::StrCat( - icu_error.errorName(), ": Could not normalize input string: ", input)); - } - return output_text; -} - -// Returns all valid Unicode codepoints. -std::vector AllValidUnicodeCodePoints() { - std::vector ret; - // The maximum codepoint in Unicode is 0x0010FFFF. - for (char32_t cp = 0; cp <= 0x0010FFFF; ++cp) { - if (!U_IS_UNICODE_CHAR(cp)) { - continue; - } - ret.push_back(cp); - } - return ret; -} - -// Calls the original methods as in BertTokenizer (e.g., icu lib, etc.) to -// normalize the input. Based on -// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py. -absl::StatusOr OriginalNormalizeText( - absl::string_view input, bool lower_case_nfd_strip_accents) { - static const RE2* const kMnRegex = new RE2("\\p{Mn}"); - static const RE2* const kControlRegex = new RE2("\\p{Cc}|\\p{Cf}"); - std::string output_text = std::string(input); - // Lowercase and strip accents (if option is set) - if (lower_case_nfd_strip_accents) { - SH_ASSIGN_OR_RETURN(output_text, case_fold_utf8(output_text)); - SH_ASSIGN_OR_RETURN(output_text, normalize_utf8_nfd(output_text)); - RE2::GlobalReplace(&output_text, *kMnRegex, ""); - } - - // Replace control characters with spaces. 
- RE2::GlobalReplace(&output_text, *kControlRegex, " "); - - return output_text; -} -} // namespace - -absl::StatusOr BuildFastBertNormalizerModelAndExportToFlatBuffer( - bool lower_case_nfd_strip_accents) { - const auto& text_normalizer = - FastBertNormalizerFactory::GetInstance(lower_case_nfd_strip_accents); - flatbuffers::FlatBufferBuilder builder; - const auto array = builder.CreateVector(text_normalizer.GetTrieData()); - const auto mapped_string_pool = builder.CreateVector( - std::vector(text_normalizer.GetMappedValuePool().begin(), - text_normalizer.GetMappedValuePool().end())); - auto text_normalizer_model = CreateFastBertNormalizerModel( - builder, lower_case_nfd_strip_accents, array, - text_normalizer.GetDataForCodepointZero(), mapped_string_pool); - builder.Finish(text_normalizer_model); - return std::string(reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()); -} - -/*static*/ absl::Status FastBertNormalizerFactory::BuildFastBertNormalizer( - bool lower_case_nfd_strip_accents, std::vector& trie_data, - int& data_for_codepoint_zero, std::string& mapped_value_string_pool) { - // Prepare the string keys and the encoded values. - std::vector keys; - std::vector values; - mapped_value_string_pool = ""; - data_for_codepoint_zero = 0; - // Memorize and reuse normalized strings. - absl::flat_hash_map norm_string_to_pool_offset; - - for (const auto cp : AllValidUnicodeCodePoints()) { - // Get the utf8 view of the codepoint. - char buf[4]; - int len = 0; - U8_APPEND_UNSAFE(buf, len, cp); - const absl::string_view cp_view(buf, len); - // Normalize. - SH_ASSIGN_OR_RETURN( - auto cp_norm, - OriginalNormalizeText(cp_view, lower_case_nfd_strip_accents)); - int data = 0; - if (cp_norm != cp_view) { - // The mapped value is different from the input. - data |= text_norm::kIsNormalizedStringDifferentMask; - // Encode the mapped value into `data`. 
- if (!cp_norm.empty()) { - const auto itr = norm_string_to_pool_offset.find(cp_norm); - int current_offset = 0; - if (itr == norm_string_to_pool_offset.end()) { - if (cp_norm.size() > - text_norm::kMaximumUtf8LengthOfNormalizedString) { - LOG(ERROR) << "The length of mapped value exceeds the maximum " - "supported. Codepoint: " - << uint32_t{cp} - << ". Mapped value length: " << cp_norm.size() - << ". Maximum supported length: " - << text_norm::kMaximumUtf8LengthOfNormalizedString; - } - current_offset = mapped_value_string_pool.size(); - if (current_offset > text_norm::kMaximumOffsetOfNormalizedString) { - LOG(ERROR) << "The offset of mapped value exceeds the maximum " - "supported. Codepoint: " - << uint32_t{cp} - << ". Mapped value offset: " << current_offset - << ". Maximum supported length: " - << text_norm::kMaximumOffsetOfNormalizedString; - } - norm_string_to_pool_offset[cp_norm] = current_offset; - absl::StrAppend(&mapped_value_string_pool, cp_norm); - } else { - current_offset = norm_string_to_pool_offset[cp_norm]; - } - data |= cp_norm.size(); - data |= (current_offset - << text_norm::kBitsToEncodeUtf8LengthOfNormalizedString); - } - } - // Store the encoded data. - if (cp == 0) { - data_for_codepoint_zero = data; - // Skip encoding it into the trie since Darts_clone cannot encode the - // empty string. - continue; - } - if (data == 0) { - // Data is not set when normalizing the codepoint doesn't change it. These - // characters aren't encoded to save space. - continue; - } - // Key is the utf8 view; value is the encoded data. - keys.emplace_back(buf, len); - values.push_back(data); - } - // Build the trie. - SH_ASSIGN_OR_RETURN(trie_data, trie_utils::BuildDartsCloneTrie(keys, values)); - LOG(INFO) << "CharacterSet built (lower_case_nfd_strip_accents=" - << lower_case_nfd_strip_accents - << "). Trie data size (int32): " << trie_data.size() - << ". 
Normalized string pool size (byte): " - << mapped_value_string_pool.size(); - return absl::OkStatus(); -} - -FastBertNormalizerFactory::FastBertNormalizerFactory( - bool lower_case_nfd_strip_accents) { - auto status = - BuildFastBertNormalizer(lower_case_nfd_strip_accents, trie_data_, - data_for_codepoint_zero_, mapped_value_pool_); - if (!status.ok()) { - // Should never happen since the same code must have passed the unit tests. - LOG(ERROR) << "Unexpected error. Failed to build the data for " - "FastBertNormalizer. Error message: " - << status.message(); - return; - } - auto char_set_recognizer_mapper = FastBertNormalizer::Create( - trie_data_.data(), data_for_codepoint_zero_, mapped_value_pool_.data()); - if (!char_set_recognizer_mapper.ok()) { - // Should never happen since the same code must have passed the unit tests. - LOG(ERROR) << "Unexpected error: Failed to initialize " - "FastBertNormalizer from the data."; - return; - } - char_set_normalizer_ = std::make_unique( - *std::move(char_set_recognizer_mapper)); -} -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h b/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h index a2e57b7c5..f8b8d3c8f 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h @@ -15,86 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_BUILDER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_BUILDER_H_ -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer.h" - -namespace tensorflow { -namespace text { - -// Builds a FastBertNormalizer model in flatbuffer format. 
-// -// Args: -// * lower_case_nfd_strip_accents: If true, a preprocessing step is added to -// lowercase the text, apply NFD normalization, and strip accents characters. -// -// Returns: -// The bytes of the flatbuffer that stores the model. -absl::StatusOr BuildFastBertNormalizerModelAndExportToFlatBuffer( - bool lower_case_nfd_strip_accents); - -/// A singleton class to initialize FastBertNormalizer and also to -/// own the data for it. -class FastBertNormalizerFactory { - public: - // Returns the singleton instance. - // - // Args: - // lower_case_nfd_strip_accents: bool - // - If true, it first lowercases the text, applies NFD normalization, - // strips accents characters, and then replaces control characters with - // whitespaces. - // - If false, it only replaces control characters with whitespaces. - static const FastBertNormalizerFactory& GetInstance( - bool lower_case_nfd_strip_accents) { - if (lower_case_nfd_strip_accents) { - return GetInstanceLowerCase(); - } else { - return GetInstanceNoLowerCase(); - } - } - - const FastBertNormalizer* GetNormalizer() const { - return char_set_normalizer_.get(); - } - - const std::vector& GetTrieData() const { return trie_data_; } - - int GetDataForCodepointZero() const { return data_for_codepoint_zero_; } - - absl::string_view GetMappedValuePool() const { return mapped_value_pool_; } - - private: - FastBertNormalizerFactory(bool lower_case_nfd_strip_accents); - - // Returns a singleton instance with lower_case_nfd_strip_accents = false. - static const FastBertNormalizerFactory& GetInstanceNoLowerCase() { - static const FastBertNormalizerFactory* const kInstance = - new FastBertNormalizerFactory(false); - return *kInstance; - } - - // Returns a singleton instance with lower_case_nfd_strip_accents = true. 
- static const FastBertNormalizerFactory& GetInstanceLowerCase() { - static const FastBertNormalizerFactory* const kInstance = - new FastBertNormalizerFactory(true); - return *kInstance; - } - - // Returns the data to build a FastBertNormalizer. - static absl::Status BuildFastBertNormalizer( - bool lower_case_nfd_strip_accents, std::vector& trie_data, - int& data_for_codepoint_zero, std::string& mapped_value_string_pool); - - std::vector trie_data_; - int data_for_codepoint_zero_ = 0; - std::string mapped_value_pool_ = ""; - std::unique_ptr char_set_normalizer_ = nullptr; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_bert_normalizer_model_builder.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.cc b/tensorflow_text/core/kernels/fast_bert_normalizer_model_generated.h similarity index 63% rename from tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.cc rename to tensorflow_text/core/kernels/fast_bert_normalizer_model_generated.h index 7faba4703..06160f4ec 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.cc +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_model_generated.h @@ -12,16 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.h" +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_GENERATED_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_GENERATED_H_ -#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/text/fast_bert_normalizer_model_generated.h" -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER( - Name(FastBertNormalizeOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - FastBertNormalizeOpKernel); - -} // namespace text -} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_MODEL_GENERATED_H_ diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_test.cc b/tensorflow_text/core/kernels/fast_bert_normalizer_test.cc deleted file mode 100644 index 73b4a4c51..000000000 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_test.cc +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/fast_bert_normalizer.h" - -#include - -#include -#include -#include "tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h" - -namespace tensorflow { -namespace text { -namespace { - -template -std::string ListToString(const std::vector& list) { - return absl::StrCat("[", absl::StrJoin(list, ", "), "]"); -} - -// Testing spec struct for parameterized tests. 
-struct Spec { - friend std::ostream& operator<<(std::ostream& os, const Spec& s) { - return os << "input: " << s.input << ", " - << "lower_case_nfd_strip_accents:" - << s.lower_case_nfd_strip_accents << ", " - << "expected_output:" << s.expected_output << ", " - << "expected_offset_mapping:" - << ListToString(s.expected_offset_mapping) << std::endl; - } - - std::string input; - bool lower_case_nfd_strip_accents = false; - std::string expected_output; - std::vector expected_offset_mapping; -}; - -// Parameterized tests specs for FastBertNormalizer. -const std::vector& GetTestSpecs() { - static const std::vector& v = *new std::vector{ - // Test Suite 1: No lower case. - // Test 0: Empty input. - { - .input = "", - .lower_case_nfd_strip_accents = false, - .expected_output = "", - .expected_offset_mapping = {0}, - }, - // Test 1: All ascii, digit, and normal letters. - { - .input = "Test #1.", - .lower_case_nfd_strip_accents = false, - .expected_output = "Test #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8}, - }, - // Test 2: Multi-byte letters. - // "\xC3\x80" is U+00C0 "Latin Capital Letter A with Grave". - // "\x41\xCC\x80" is the decomposition of U+00C0 "Latin Capital Letter A - // with Grave". - { - .input = "Test\xC3\x80\x41\xCC\x80 #1.", - .lower_case_nfd_strip_accents = false, - .expected_output = "Test\xC3\x80\x41\xCC\x80 #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13}, - }, - // Test 3: Control chars normalized into whitespaces. - { - .input = "Te\x11st #1.", - .lower_case_nfd_strip_accents = false, - .expected_output = "Te st #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, - }, - // Test 4: Tabs and newlines normalized into whitespaces. - { - .input = "Test \t\n#1.", - .lower_case_nfd_strip_accents = false, - .expected_output = "Test #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - }, - // Test Suite 2: Lower case. - // Test 5: Empty input. 
- { - .input = "", - .lower_case_nfd_strip_accents = true, - .expected_output = "", - .expected_offset_mapping = {0}, - }, - // Test 6: All ascii, digit, and normal letters. - { - .input = "Test #1.", - .lower_case_nfd_strip_accents = true, - .expected_output = "test #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8}, - }, - // Test 7: Multi-byte letters. - // "\xC3\x80" is U+00C0 "Latin Capital Letter A with Grave", which is - // normalized to "a". "\x41\xCC\x80" is the decomposition of U+00C0 "Latin - // Capital Letter A with Grave", which is normalized to "a". - { - .input = "Test\xC3\x80\x41\xCC\x80 #1.", - .lower_case_nfd_strip_accents = true, - .expected_output = "testaa #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 6, 9, 10, 11, 12, 13}, - }, - // Test 8: Control chars normalized into whitespaces. - { - .input = "Te\x11st #1.", - .lower_case_nfd_strip_accents = true, - .expected_output = "te st #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, - }, - // Test 9: Tabs and newlines normalized into whitespaces. - { - .input = "Test \t\n#1.", - .lower_case_nfd_strip_accents = true, - .expected_output = "test #1.", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - }, - // Test 10: Multibytes string normalized into multibytes string. - // "\xC2\xBC" (2 bytes) is normalized into "1\xE2\x81\x84""4" (5 bytes). - { - .input = "a\xC2\xBC", - .lower_case_nfd_strip_accents = true, - .expected_output = "a1\xE2\x81\x84" - "4", - .expected_offset_mapping = {0, 1, 1, 1, 1, 1, 3}, - }, - // Test 11: Multibytes string normalized into multibytes string. - // "\xC7\xB2" (2 bytes) is normalized into "dz" (2 bytes). - { - .input = "a\xC7\xB2", - .lower_case_nfd_strip_accents = true, - .expected_output = "adz", - .expected_offset_mapping = {0, 1, 1, 3}, - }, - // Test 12: Multibytes string normalized into multibytes string. - // "\xCE\xB9" (2 bytes) is normalized into "\xCE\xB7" (2 bytes). 
- { - .input = "a\xCE\x89", - .lower_case_nfd_strip_accents = true, - .expected_output = "a\xCE\xB7", - .expected_offset_mapping = {0, 1, 1, 3}, - }, - // Test 13: Invalid UTF8 input. lower_case_nfd_strip_accents = false. - { - .input = "a\x80 \xFF \xF8 a\xE0\x61 \xF3\x9C\x9D", - .lower_case_nfd_strip_accents = false, - .expected_output = "a\x80 \xFF \xF8 a\xE0\x61 \xF3\x9C\x9D", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14}, - }, - // Test 14: Invalid UTF8 input. lower_case_nfd_strip_accents = true. - { - .input = "a\x80 \xFF \xF8 a\xE0\x61 \xF3\x9C\x9D", - .lower_case_nfd_strip_accents = true, - .expected_output = "a\x80 \xFF \xF8 a\xE0\x61 \xF3\x9C\x9D", - .expected_offset_mapping = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14}, - }, - }; - return v; -} - -using TestNormalization = testing::TestWithParam; - -TEST_P(TestNormalization, TestGetOffsets) { - const auto spec = GetParam(); - const auto fast_bert_normalizer = - FastBertNormalizerFactory::GetInstance(spec.lower_case_nfd_strip_accents) - .GetNormalizer(); - - std::string output_normalized_text = "Something existing"; - std::vector output_normalized_offset_mapping; - bool is_normalized_identical; - fast_bert_normalizer->NormalizeText( - spec.input, &is_normalized_identical, &output_normalized_text, - &output_normalized_offset_mapping); - if (is_normalized_identical) { - ASSERT_THAT(output_normalized_text, ""); - ASSERT_THAT(spec.input, spec.expected_output); - ASSERT_THAT(output_normalized_offset_mapping, testing::ElementsAre()); - } else { - ASSERT_THAT(output_normalized_text, spec.expected_output); - ASSERT_THAT(output_normalized_offset_mapping, spec.expected_offset_mapping); - } -} - -TEST_P(TestNormalization, TestNoGetOffsets) { - const auto spec = GetParam(); - const auto fast_bert_normalizer = - FastBertNormalizerFactory::GetInstance(spec.lower_case_nfd_strip_accents) - .GetNormalizer(); - - std::string output_normalized_text; - std::vector 
output_normalized_offset_mapping; - bool is_normalized_identical; - fast_bert_normalizer->NormalizeText( - spec.input, &is_normalized_identical, &output_normalized_text, - /*output_normalized_offset_mapping=*/nullptr); - if (is_normalized_identical) { - ASSERT_THAT(spec.input, spec.expected_output); - ASSERT_THAT(output_normalized_text, ""); - } else { - ASSERT_THAT(output_normalized_text, spec.expected_output); - } -} - -INSTANTIATE_TEST_SUITE_P(FastBertNormalizerTest, TestNormalization, - testing::ValuesIn(GetTestSpecs())); -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.h b/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.h index 2c1547115..c3263a707 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_tf_kernel.h @@ -15,19 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TF_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TF_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h" - -namespace tensorflow { -namespace text { - -class FastBertNormalizeOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_bert_normalizer_tf_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TF_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.cc b/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.cc deleted file mode 100644 index d61202bdc..000000000 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.cc +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2026 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h" - -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -using FastBertNormalizeOpKernel = - tflite::shim::TfLiteOpKernel; - -extern "C" void AddFastBertNormalize(tflite::MutableOpResolver* resolver) { - FastBertNormalizeOpKernel::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h b/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h index d6e2de641..503277ab4 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h @@ -12,21 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ -#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/core/kernels/text/fast_bert_normalizer_tflite.h" -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddFastBertNormalize(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_BERT_NORMALIZER_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc deleted file mode 100644 index 88feee121..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc +++ /dev/null @@ -1,748 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h" - -#include - -#include "absl/base/attributes.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h" - -namespace tensorflow { -namespace text { -namespace { - -template -int GetCurrentOutputSize(std::vector* output_pieces, - std::vector* output_ids) { - if constexpr (kGetPieces) { - return output_pieces->size(); - } else { - return output_ids->size(); - } -} - -} // namespace - -/*static*/ absl::StatusOr -FastWordpieceTokenizer::Create(const void* config_flatbuffer) { - FastWordpieceTokenizer tokenizer; - // `GetFastWordpieceTokenizerConfig()` is autogenerated by flatbuffer. 
- tokenizer.config_ = GetFastWordpieceTokenizerConfig(config_flatbuffer); - auto trie_or = trie_utils::DartsCloneTrieWrapper::Create( - tokenizer.config_->trie_array()->data()); - if (!trie_or.ok()) { - return absl::InvalidArgumentError( - "Failed to create DartsCloneTrieWrapper from " - "FastWordpieceTokenizerConfig.trie_array."); - } - tokenizer.trie_ = - std::make_unique(*std::move(trie_or)); - return std::move(tokenizer); -} - -void FastWordpieceTokenizer::Tokenize(absl::string_view input, - std::vector* output_pieces, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets, - int input_word_offset_in_text, - bool* error) const { - if (config_->end_to_end()) { - TokenizeTextImpl(input, output_pieces, output_ids, - output_start_offsets, - output_end_offsets, error); - } else { - TokenizeSingleWordImpl( - input, input_word_offset_in_text, output_pieces, output_ids, - output_start_offsets, output_end_offsets); - } -} - -void FastWordpieceTokenizer::Tokenize(absl::string_view input, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets, - int input_word_offset_in_text) const { - if (config_->end_to_end()) { - TokenizeTextImpl( - input, /*output_pieces=*/nullptr, output_ids, output_start_offsets, - output_end_offsets, /*error=*/nullptr); - } else { - TokenizeSingleWordImpl( - input, input_word_offset_in_text, /*output_pieces=*/nullptr, output_ids, - output_start_offsets, output_end_offsets); - } -} - -void FastWordpieceTokenizer::Tokenize(absl::string_view input, - std::vector* output_ids, - int input_word_offset_in_text) const { - if (config_->end_to_end()) { - TokenizeTextImpl( - input, /*output_pieces=*/nullptr, output_ids, - /*output_start_offsets=*/nullptr, - /*output_end_offsets=*/nullptr, /*error=*/nullptr); - } else { - TokenizeSingleWordImpl( - input, input_word_offset_in_text, /*output_pieces=*/nullptr, output_ids, - /*output_start_offsets=*/nullptr, - 
/*output_end_offsets=*/nullptr); - } -} - -absl::StatusOr> -FastWordpieceTokenizer::DetokenizeToTokens( - const absl::Span input) const { - std::vector subwords; - std::vector output_tokens; - if (!config_->support_detokenization()) { - return absl::FailedPreconditionError( - "Detokenize function is only enabled when support_detokenization is " - "true in the config flatbuffer. Please rebuild the model flatbuffer " - "by setting support_detokenization=true."); - } - for (int id : input) { - auto vocab = config_->vocab_array()->Get(id); - auto is_suffix = config_->vocab_is_suffix_array()->Get(id); - if (!subwords.empty() && !is_suffix) { - // When current subword is not a suffix token, it marks the start of a new - // word. We concatenate the subwords that compose the previous word and - // add it to the return list. - output_tokens.emplace_back(absl::StrJoin(subwords, "")); - subwords.clear(); - } - // Special case: when a suffix token e.g. "##a" appears at the start of the - // input ids, we preserve the suffix_indicator. - if (subwords.empty() && is_suffix) { - subwords.emplace_back(config_->suffix_indicator()->string_view()); - } - subwords.emplace_back(vocab->string_view()); - } - if (!subwords.empty()) { - output_tokens.emplace_back(absl::StrJoin(subwords, "")); - } - return output_tokens; -} - -absl::StatusOr FastWordpieceTokenizer::Detokenize( - const absl::Span input) const { - SH_ASSIGN_OR_RETURN(std::vector output_tokens, - DetokenizeToTokens(input)); - return absl::StrJoin(output_tokens, " "); -} - -int FastWordpieceTokenizer::SkipTheRemainingOfWordAndTrailingWhiteSpaces( - absl::string_view input, int& cur_pos) const { - const int input_size = input.size(); - UChar32 cur_unicode_char; - int next_pos; - int end_of_word = cur_pos; - while (cur_pos < input_size) { - next_pos = cur_pos; - U8_NEXT(input, next_pos, input_size, cur_unicode_char); - if (u_isUWhiteSpace(cur_unicode_char)) { - cur_pos = next_pos; // Skip the whitespace as well. 
- // Break and return since we've met a word boundary. - break; - } - if (fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar( - cur_unicode_char)) { - // Break and return since we've met a word boundary. We do not skip the - // punctuation character: that character may be a token by itself. - break; - } - end_of_word = next_pos; // Mark the exclusive end. - cur_pos = next_pos; // Skip the character. - } - return end_of_word; -} - -template -void FastWordpieceTokenizer::TokenizeTextImpl( - absl::string_view input_text, std::vector* output_pieces, - std::vector* output_ids, std::vector* output_start_offsets, - std::vector* output_end_offsets, bool* error) const { - static_assert(kGetPieces || kGetIds, - "At least one of `kGetPieces` and `kGetIds` should be true."); - if (input_text.empty()) { - return; - } - const int input_size = input_text.size(); - int prev_pos = -1; - int next_pos = 0; - int cur_pos = 0; - int original_num_tokens = - GetCurrentOutputSize(output_pieces, output_ids); - UChar32 prev_unicode_char; - UChar32 cur_unicode_char; - while (cur_pos < input_size) { - // Prevent looping without progress in cur_pos. - if (prev_pos == cur_pos && error != nullptr) { - *error = true; - return; - } - prev_pos = cur_pos; - - int cur_offset_in_input_word = 0; - // Tokenize the word starting at the current position. - auto cur_node = trie_->CreateTraversalCursorPointToRoot(); - int word_byte_length_so_far = 0; - int input_word_offset_in_text = cur_pos; - absl::string_view input_substr = input_text.substr(cur_pos); - // The trie matching loop below tokenizes and recognizes word pieces until - // 1. it steps over the input boundary, or - // 2. the length of the current word reaches 'max_bytes_per_token', or - // 3. it sees a whitespace / punctuation / unknown character. - int prev_pos_inner = -1; - while (cur_pos < input_size) { - // Prevent looping without progress in cur_pos. 
- if (prev_pos_inner == cur_pos && error != nullptr) { - *error = true; - return; - } - prev_pos_inner = cur_pos; - - prev_unicode_char = cur_unicode_char; - next_pos = cur_pos; - U8_NEXT(input_text, next_pos, input_text.length(), cur_unicode_char); - - if (word_byte_length_so_far + next_pos - cur_pos > - config_->max_bytes_per_token()) - break; - // Try matching one Unicode character from here. - while (!trie_->TryTraverseSeveralSteps( - cur_node, input_text.substr(cur_pos, next_pos - cur_pos))) { - // Trie cannot consume the whole Unicode character. We need to pop one - // or more longest-matching tokens off the beginning of the string - // represented by the current node. We then transit to the node pointed - // by the failure link, which represents the remaining suffix string - // after popping those matching prefix tokens. - // - // For example, if the current node is "abcdef", and we need to pop - // "ab", and "##cd" off the beginning, the failure link points to the - // node that represents "##ef". - if (!TryFollowFailureLinkAndCollectTokens( - input_substr, input_word_offset_in_text, - cur_offset_in_input_word, cur_node, output_pieces, output_ids, - output_start_offsets, output_end_offsets)) { - goto outside_trie_match_loop; - } - } - // Trie consumed the whole Unicode char and was able to traverse to a - // new node. We move forward the cursor to match the next character. - word_byte_length_so_far += next_pos - cur_pos; - cur_pos = next_pos; - } - outside_trie_match_loop: - if (cur_pos >= input_size) { - // Collect the remaining tokens stored on a path on the trie. - HandleTheRemainingStringOnTriePath( - input_substr, input_word_offset_in_text, cur_node, - original_num_tokens, cur_offset_in_input_word, output_pieces, - output_ids, output_start_offsets, output_end_offsets); - // Break as we've finished all characters. 
- break; - } - bool is_white_space = u_isUWhiteSpace(cur_unicode_char); - if (is_white_space || - fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar( - cur_unicode_char) || - (cur_pos && fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar( - prev_unicode_char))) { - // If the current Unicode character is a valid word boundary, collect the - // remaining tokens stored on a path on the trie. - absl::string_view cur_str = absl::string_view( - input_substr.data(), cur_pos - input_word_offset_in_text); - HandleTheRemainingStringOnTriePath( - cur_str, input_word_offset_in_text, cur_node, original_num_tokens, - cur_offset_in_input_word, output_pieces, output_ids, - output_start_offsets, output_end_offsets); - if (is_white_space) { - // Skip the whitespace. - cur_pos = next_pos; - } else if (cur_str.empty()) { - // If the remaining tokens are empty, it means we encountered an - // unmappable separator, so output an unknown token and continue. - cur_pos = next_pos; - ResetOutputAppendUnknownToken( - input_word_offset_in_text, (cur_pos - input_word_offset_in_text), - original_num_tokens, output_pieces, output_ids, - output_start_offsets, output_end_offsets); - } - // Continue in the outer while loop to process the remaining input. - continue; - } - - // Note that even with the following line removed, the code is still correct - // (i.e., Mutants is right). We keep this line for efficiency reasons: We - // have tested the current char, and it is not a whitespace or punctuation - // char. Hence it's safe to skip the current char; we don't want to test it - // again in the subsequent function. - cur_pos = next_pos; - int end_of_word = - SkipTheRemainingOfWordAndTrailingWhiteSpaces(input_text, cur_pos); - - // The current character is not a word boundary. The case is simple: We are - // at the start or middle of some word with unknown characters or exceeding - // the length limit. We map the entire word unk_token, skip the remaining - // portion, and continue. 
- ResetOutputAppendUnknownToken( - input_word_offset_in_text, (end_of_word - input_word_offset_in_text), - original_num_tokens, output_pieces, output_ids, output_start_offsets, - output_end_offsets); - } -} -// This function implements the new linear WordPiece algorithm. The overall -// design is illustrated as follows: -// -// * WordPiece tokenization works in a left-to-right longest-matching-first -// greedy manner, known as maximum matching. -// -// * We use a trie containing all pieces from the vocabulary. -// -// * We iterate the input text left-to-right, following the trie in search of -// longer and longer matches. -// -// * Challenge: When we fall off the trie matching, the best match is usually -// several characters back. -// -// * For example, assume the vocabulary is {a, ab, ##cd, ##efz, abcdefg}. -// If the input is "abcdefz", the trie matching stops at the position of -// "z". However, the longest match is "ab", which is 5 characters back. -// -// * Straightforward solution: Remember the last match while iterating on the -// trie. That gives us the longest match. Then we roll our string iterator -// backwards and reprocess the characters that weren't part of the match. It -// can be proved that the time complexity is quadratic. -// -// * For the example above, it will backtrack to the 3rd position and -// restart matching from "c", resulting in repetitive, wasteful iterations. -// -// * Optimized solution (the novel linear algorithm): Instead of having to -// reprocess the letters that didn't match, we can have the trie record -// (1) the longest-matching tokens that we would have identified (called -// "failure pops") and (2) a link pointing to a node (called "failure link") -// representing the state from where we can continue to match the next -// character. 
When trie matching cannot consume an input character, we perform -// a "failure transition" by (a) appending the failure pops to the tokenization -// result and (b) transiting through the failure link to a new state to -// continue the process. Our string iterator never backtracks, and it can be -// proved that we make at most `n` failure transitions in total in processing a -// string of length `n`. Therefore, the time complexity is linear. -// -// * For the same example above, when the trie matching fails at the -// character "z", the optimized solution is smart enough to know that the -// longest-matching tokens we can collect are ["ab", "##cd"]. It is also -// smart enough to set itself into such a state as if it has only seen and -// matched "##ef" so far. Now given the next character being "z", it -// immediately identifies the next matching token as "##efz". -template -void FastWordpieceTokenizer::TokenizeSingleWordImpl( - absl::string_view input_word, int input_word_offset_in_text, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - static_assert(kGetPieces || kGetIds, - "At least one of `kGetPieces` and `kGetIds` should be true."); - if (input_word.empty()) { - return; - } - const int input_size = input_word.size(); - - // `original_num_tokens` stores the number of tokens in the output before - // tokenizing this `input_word`. This is needed because we attempt to tokenize - // `input_word` into word piece tokens and append the recognized tokens to the - // outputs on the fly. If we later find out that `input_word` cannot be - // tokenized into sub-tokens with the current vocabulary, we roll-back the - // output vectors (by removing those tentative tokens) based on - // `original_num_tokens` and appends the "unk_token". 
- int original_num_tokens = - GetCurrentOutputSize(output_pieces, output_ids); - - if (input_word.size() > config_->max_bytes_per_token()) { - ResetOutputAppendUnknownToken( - input_word_offset_in_text, input_size, original_num_tokens, - output_pieces, output_ids, output_start_offsets, output_end_offsets); - return; - } - - // `cur_offset_in_input_word` tracks the offset of the remaining portion of - // `input_word`, for which the tokens are yet to be recognized and outputted. - // Initially it just points to the start of the input. And it gets moved - // when more tokens are outputed. - // - // For example, suppose the vocab is {a,abcd,##b,##bc,##z}, and the input is - // "abcz". First `cur_offset_in_input_word` points to position 0, since we - // haven't ouputted any tokens. After the first token "a" is recognized and - // outputted, it moves passing the substring "a" to position 1. Then after the - // second token "##bc" is recognized and put to the outputs, it moves passing - // the substring "bc" to position 3. - // - // This variable is used to calculate the offsets of each word piece token. - // And since knowing their offsets in the input word, we're also able to get - // the token string without looking it up in the vocabulary table. This saves - // an extra look-up in hash table (saving time), and we don't even need to - // save the vocabulary table anymore (saving memory). - int cur_offset_in_input_word = 0; - - // Here is an example to illustrate the inference process. - // - // Suppose the vocabulary is {a,abcd,##b,##bc,##z}, and the suffix indicator - // is ##. 
Below is the trie built from that vocabulary: - // - // (a) (b) (c) (d) - // 0 ----- 3 ----- 4 ----- 5 ----- 6 - // (#)| - // 1 - // (#)| (b) (c) - // 2 ----- 7 ----- 8 - // | (z) - // + ----- 9 - // - // The algorithm constructs auxiliary structures on top of the trie to enable - // linear inference, which consist of two parts (let v denote a node): - // * failure links f(v), pointing to another node, - // * failure pops F(v), a list of tokens stored on node v. - // - // The table of str(v) (which is the string along the trie path from the root - // to node v), f(v), and F(v) for the above trie is as follows: - // - // v | 0 1 2 3 4 5 6 7 8 9 - // str(v)| "" # ## a ab abc abcd ##b ##bc ##z - // F(v)| [] [] [] [a] [a] [a] [abcd] [##b] [##bc] [##z] - // f(v)| null null null 2 7 8 2 2 2 null - // - // Please refer to `FastWordpieceTokenizerBuilder.h|cc` for detailed - // information on how failure links and failure pops are constructed. - // - // Let the input word be "abcz". Below is the inference process that is - // carried out by this method. - // - // Step | Char | Node transition | Output - // 0 | | 0 | [] - // 1 | a | goto(0,a) -> 3 | [] - // 2 | b | goto(3,b) -> 4 | [] - // 3 | c | goto(4,c) -> 5 | [] - // 4 | z | f(5) -> 8 | [a] - // | z | f(8) -> 2 | [a, ##bc] - // | z | goto(2,z) -> 9 | [a, ##bc] - // final | f(9) -> 2 | [a, ##bc, ##z] - // - // Notes: - // * In each step we match and process one input character. - // * goto(u,c) -> v: following the trie link with label c to transit from node - // u to node v. - // * f(u) -> v: following the failure link to transit from node u to node v. - // * The "final" step means that after processing all input characters, we - // keep transiting through the failure links until arriving at the node 2 - // that represents the suffix indicator "##". - // - // Please refer to the below code and comments. - - // Start from the root of the trie. 
- auto cur_node = trie_->CreateTraversalCursorPointToRoot(); - - for (auto ch : input_word) { - // Although the matching is on Unicode codepoints, it is equivalent to - // directly work with the utf-8 encoding bytes. - while (!trie_->TryTraverseOneStep(cur_node, ch)) { - // Trie cannot consume `ch`. As explained earlier (see "Optimized - // solution" above) we need to (1) pop one or more longest-matching tokens - // (i.e., failure pops) off the start of the string represented by the - // current node, and (2) transit through the failure link to a node that - // represents the remaining suffix string after popping those - // longest-matching prefix tokens. - if (!TryFollowFailureLinkAndCollectTokens( - input_word, input_word_offset_in_text, cur_offset_in_input_word, - cur_node, output_pieces, output_ids, output_start_offsets, - output_end_offsets)) { - // If unable to follow the failure link, it means that the current trie - // node doesn't have any matching prefix vocab tokens to pop. Since the - // next character is not associated with a valid trie edge, the entire - // word cannot be tokenized. - ResetOutputAppendUnknownToken( - input_word_offset_in_text, input_size, original_num_tokens, - output_pieces, output_ids, output_start_offsets, - output_end_offsets); - return; - } - } - // Trie consumed `ch` and was able to traverse to a new node. Continue and - // process the next character. - } - // Segment the remaining string on the trie into tokens and collect them, or - // determine that the word cannot be tokenized. 
- HandleTheRemainingStringOnTriePath( - input_word, input_word_offset_in_text, cur_node, original_num_tokens, - cur_offset_in_input_word, output_pieces, output_ids, output_start_offsets, - output_end_offsets); -} - -template -ABSL_ATTRIBUTE_ALWAYS_INLINE bool -FastWordpieceTokenizer::TryFollowFailureLinkAndCollectTokens( - absl::string_view input_word, int input_word_offset_in_text, - int& cur_offset_in_input_word, - trie_utils::DartsCloneTrieWrapper::TraversalCursor& node, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - int cur_node_data; - if (trie_->TryGetData(node, cur_node_data)) { - // A shortcut to get f(cur_node) (i.e., the failure link) and F(cur_node) - // (i.e., failure pops) when `cur_node` has data. This results in ~10% - // speedup (statistically significant). - AppendTokenToOutput( - input_word, input_word_offset_in_text, cur_offset_in_input_word, - cur_node_data, output_pieces, output_ids, output_start_offsets, - output_end_offsets); - // Transit through the failure link. - trie_->SetTraversalCursor( - node, - config_->failure_struct_array()->Get(node.node_id)->failure_link()); - return true; - } - - const auto& node_aux = config_->failure_struct_array()->Get(node.node_id); - - if (node_aux->failure_link() == fast_wordpiece_tokenizer_utils::kNullNode) { - // No failure_link can be followed. - return false; - } - - // Collect the tokens (i.e., failure pops), represented by (offset, length) in - // a failure_pops pool (held by the config flatbuffer). 
- int failure_pops_offset, failure_pops_length; - fast_wordpiece_tokenizer_utils::GetFailurePopsOffsetAndLength( - node_aux->failure_pops_offset_length(), failure_pops_offset, - failure_pops_length); - const int failure_pops_end_offset = failure_pops_offset + failure_pops_length; - for (int offset_in_pool = failure_pops_offset; - offset_in_pool < failure_pops_end_offset; ++offset_in_pool) { - AppendTokenToOutput( - input_word, input_word_offset_in_text, cur_offset_in_input_word, - config_->failure_pops_pool()->Get(offset_in_pool), output_pieces, - output_ids, output_start_offsets, output_end_offsets); - } - - // Transit through the failure link. - trie_->SetTraversalCursor(node, node_aux->failure_link()); - return true; -} - -template -void FastWordpieceTokenizer::AppendTokenToOutput( - absl::string_view input_word, int input_word_offset_in_text, - int& cur_offset_in_input_word, int encoded_token_value, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - auto token_id = - fast_wordpiece_tokenizer_utils::GetTokenId(encoded_token_value); - if constexpr (kGetIds) { - output_ids->push_back(token_id); - } - if constexpr (kGetPieces || kGetOffsets) { - // For suffix tokens, the length below is without the suffix indicator. - int token_substr_length = - fast_wordpiece_tokenizer_utils::GetTokenLength(encoded_token_value); - if (!cur_offset_in_input_word && - fast_wordpiece_tokenizer_utils::IsSuffixToken(encoded_token_value)) { - // This is a special case where `input_word` happens to start with the - // suffix indicator (e.g., "##") and a suffix token is recognized at the - // start (since `cur_offset_input_word == 0`). In this case, we need - // to adjust and add the length of the suffix indicator string. 
- token_substr_length += config_->suffix_indicator()->size(); - } - if constexpr (kGetPieces) { - // If token id is unk_token_id, it means that it is a dummy node for - // punctuations that are not contained in the vocabulary, we append - // the unk_token in this case. Otherwise, we - // get the subword string from `input_word` by the offset and length. - auto unk_token = config_->unk_token()->string_view(); - auto subword_str = - (token_id == config_->unk_token_id()) - ? absl::string_view(unk_token.data(), unk_token.size()) - : absl::string_view(input_word.data() + cur_offset_in_input_word, - token_substr_length); - output_pieces->emplace_back( - cur_offset_in_input_word - ? absl::StrCat(config_->suffix_indicator()->str(), subword_str) - : subword_str); - } - if constexpr (kGetOffsets) { - // Record the offsets relative to the start of the whole text. - output_start_offsets->push_back(input_word_offset_in_text + - cur_offset_in_input_word); - output_end_offsets->push_back(input_word_offset_in_text + - cur_offset_in_input_word + - token_substr_length); - } - cur_offset_in_input_word += token_substr_length; - } -} - -template -ABSL_ATTRIBUTE_ALWAYS_INLINE void -FastWordpieceTokenizer::HandleTheRemainingStringOnTriePath( - absl::string_view input_word, int input_word_offset_in_text, - trie_utils::DartsCloneTrieWrapper::TraversalCursor& cur_node, - int& original_num_tokens, int& cur_offset_in_input_word, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - if (cur_node.node_id == trie_utils::DartsCloneTrieWrapper::kRootNodeId) { - // We've seen an empty input word. Just return. - return; - } - // Try handling the special case where the entire input word happens to be the - // suffix indicator (e.g., "##") itself. 
- if (TryHandleTheInputWordBeingSuffixIndicatorItself( - input_word, input_word_offset_in_text, cur_node, - cur_offset_in_input_word, original_num_tokens, output_pieces, - output_ids, output_start_offsets, output_end_offsets)) { - original_num_tokens = - GetCurrentOutputSize(output_pieces, output_ids); - return; - } - - // Handle the normal case because we need to collect the remaining tokens from - // the string represented by `cur_node` (i.e., on the trie path from the trie - // root to `cur_node`), or find out the word cannot be tokenized. - // - // See the example in the comments of this function in the header file. - // - // The tokenization is successful if and only if the entire string represented - // by `cur_node` can be segmented into consecutive matching tokens, resulting - // in the empty suffix string (e.g., "##"), which is represented by - // `trie_suffix_root_`. So we keep following the failure links and collecting - // failure pops tokens until we arrive at `trie_suffix_root_` or encounter a - // null failure link in the middle. - while (cur_node.node_id != config_->trie_suffix_root() && - cur_node.node_id != config_->trie_punct_failure_link_node()) { - if (!TryFollowFailureLinkAndCollectTokens( - input_word, input_word_offset_in_text, cur_offset_in_input_word, - cur_node, output_pieces, output_ids, output_start_offsets, - output_end_offsets)) { - // The remaining string cannot be tokenized, neither can the input word. - ResetOutputAppendUnknownToken( - input_word_offset_in_text, input_word.size(), original_num_tokens, - output_pieces, output_ids, output_start_offsets, output_end_offsets); - return; - } - } - // Arrive at `trie_suffix_root_`. - - // Update the `original_num_tokens`. - original_num_tokens = - GetCurrentOutputSize(output_pieces, output_ids); - - // Succeed and exit. 
-} - -template -void FastWordpieceTokenizer::ResetOutputAppendUnknownToken( - int input_word_offset_in_text, int input_size, int& original_num_tokens, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - if constexpr (kGetPieces) { - output_pieces->resize(original_num_tokens + 1); - output_pieces->back() = config_->unk_token()->str(); - } - if constexpr (kGetIds) { - output_ids->resize(original_num_tokens + 1); - output_ids->back() = config_->unk_token_id(); - } - if constexpr (kGetOffsets) { - output_start_offsets->resize(original_num_tokens + 1); - output_start_offsets->back() = input_word_offset_in_text; - - output_end_offsets->resize(original_num_tokens + 1); - output_end_offsets->back() = input_word_offset_in_text + input_size; - } - - // Update `original_num_tokens` (since we have appended the "unk_token"). - ++original_num_tokens; -} - -template -ABSL_ATTRIBUTE_ALWAYS_INLINE bool -FastWordpieceTokenizer::TryHandleTheInputWordBeingSuffixIndicatorItself( - absl::string_view input_word, int input_word_offset_in_text, - const trie_utils::DartsCloneTrieWrapper::TraversalCursor& cur_node, - int& cur_offset_in_input_word, int original_num_tokens, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const { - // Handle the special case where the input word is the suffix indicator (e.g., - // "##") itself. This is because that, after all the characters of an input - // word were successfully processed, if we ended by standing at - // `trie_suffix_root_` but did not recognize any new tokens, it can only be - // the case that the word is the suffix indicator string (e.g., "##") itself. - // For this case we output the pre-computed result. - if (cur_node.node_id != config_->trie_suffix_root()) { - // The input word is not the suffix indicator itself. 
- return false; - } - int cur_num_tokens = - GetCurrentOutputSize(output_pieces, output_ids); - if (cur_num_tokens != original_num_tokens) { - // The input word is not the suffix indicator itself. - return false; - } - - // The input word is the suffix indicator itself. Next we handle two cases. - if (config_->precomputed_result_for_suffix_indicator()->size() == 1 && - fast_wordpiece_tokenizer_utils::GetTokenId( - config_->precomputed_result_for_suffix_indicator()->Get(0)) == - config_->unk_token_id()) { - // Case 1: The suffix indicator string cannot be tokenized but has to be - // mapped to unk_token. - ResetOutputAppendUnknownToken( - input_word_offset_in_text, input_word.size(), original_num_tokens, - output_pieces, output_ids, output_start_offsets, output_end_offsets); - return true; - } - - // Case 2: The suffix indicator can be tokenized normally. - for (int encoded_token_value : - *config_->precomputed_result_for_suffix_indicator()) { - AppendTokenToOutput( - input_word, input_word_offset_in_text, cur_offset_in_input_word, - encoded_token_value, output_pieces, output_ids, output_start_offsets, - output_end_offsets); - } - return true; -} -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h index 562c2a495..20be8a497 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h @@ -15,246 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_generated.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h" - 
-namespace tensorflow { -namespace text { - -// Applies WordPiece tokenization with an existing WordPiece vocabulary. -// -// Example: -// input = unaffable -// output = un ##aff ##able -// -// One important edge case is that if the input word contains a Unicode -// character that is not seen in the vocabulary, the entire word is mapped -// to the unknown token, which is "" by default. Otherwise, in the "worst" -// case, the word is split into characters. -// -// This is based on the WordPiece/Subword tokenizer from tensor2tensor. -// https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py -class FastWordpieceTokenizer { - public: - // Creates an instance. - // - // Args: - // * config_flatbuffer: the pointer to the FastWordpieceTokenizerConfig - // flatbuffer, which is not owned by this instance and should be kept alive - // through the lifetime of the instance. - static absl::StatusOr Create( - const void* config_flatbuffer); - - // Tokenizes `input` into its word pieces (i.e., subword tokens) and - // appends the new tokens to the end of the outputs. - // When `config_->end_to_end() is `false`, `input` should be a single - // word (after pre-tokenization by whitespaces and/or punctuations). - // Otherwise, `input` should be general text consisting of potentially many - // words. - // - // The input should be UTF-8 but the tokenization is performed on Unicode - // codepoints. - // - // - // Args: - // * input: The UTF-8 string of an input. - // * output_pieces: The output tokens. - // * output_ids: The output token ids. - // * output_start_offsets: The start offsets of output tokens in the input - // text, in utf-8 bytes. - // * output_end_offsets: The end offsets of output tokens in the input - // text, in utf-8 bytes. - // * input_word_offset_in_text: The relative offset of the input word in - // the whole text. Only used when not using end-to-end tokenizer. 
- // * error: If not null, this will be set to true if the tokenizer failed to - // make progress in decoding the input. - // Note: the start offsets are inclusive and the end offsets are exclusive. - void Tokenize(absl::string_view input, - std::vector* output_pieces, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets, - int input_word_offset_in_text = 0, bool* error = nullptr) const; - - // An override not returning `output_pieces`. - void Tokenize(absl::string_view input, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets, - int input_word_offset_in_text = 0) const; - - // An override only returning `output_ids`. - void Tokenize(absl::string_view input, std::vector* output_ids, - int input_word_offset_in_text = 0) const; - - // Detokenizes wordpiece ids into a vector of tokens. - absl::StatusOr> DetokenizeToTokens( - const absl::Span input) const; - - // Detokenizes wordpiece ids to a text. If the input string to the tokenizer - // is normalized and the tokenized wordpieces don't contain ``, the - // detokenized result of the tokenized wordpieces is the same as the original - // input text. - absl::StatusOr Detokenize( - const absl::Span input) const; - - private: - // The actual implementation of `Tokenize` when configured for single words. - // - // The template parameters `kGetPieces`, `kGetIds', and `kGetOffsets` control - // which parts of the output we generate. At least one of `kGetPieces` and - // `kGetIds` should be true. - template - void TokenizeSingleWordImpl(absl::string_view input_word, - int input_word_offset_in_text, - std::vector* output_pieces, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // The actual implementation of `Tokenize` when configured for general texts. 
- // - // The work of this method is equivalent to first splitting `input_text` into - // words (by splitting on punctuation and whitespaces, and next running - // `TokenizeSingleWordImpl` on each word. - template - void TokenizeTextImpl(absl::string_view input_text, - std::vector* output_pieces, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets, - bool* error) const; - - // Try following the failure link to make the transition when trie matching - // fails. - // - // If f(node) (i.e., failure link) is not null, it does the following: - // (1) collects tokens F(node) (i.e., failure pops) and appends to the end of - // `output_ids`, `output_pieces`, and/or `output_start_offsets` and - // `output_end_offsets`, - // (2) moves `cur_offset_in_input_word` accordingly to pass the collected - // tokens when `kGetPieces=true` or `kGetOffsets=true`, in order to - // calculate the start/end offsets of tokens and to get the token - // strings. Otherwise, `cur_offset_in_input_word` is ignored. - // (3) transits `node` to f(node) following the failure link, - // (4) returns true. - // - // If f(node) is null, it does not change anything and returns false. - // - // Args: - // * cur_offset_in_input_word: The current offset in `input_word` that - // corresponds to the start offset of the tokens that are going to be - // collected in this function. This value is used if 'kGetPieces=true' or - // 'kGetOffsets=true', and when so, this value will be updated accordingly - // after the new word piece tokens have been appended to the output. 
- template - bool TryFollowFailureLinkAndCollectTokens( - absl::string_view input_word, int input_word_offset_in_text, - int& cur_offset_in_input_word, - trie_utils::DartsCloneTrieWrapper::TraversalCursor& node, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // Appends a word piece token (represented by `encoded_token_value`) to the - // output. - // - // Args: - // * cur_offset_in_input_word: The current offset in `input_word` that - // corresponds to the start offset of the wordpiece token. This value - // is used if `kGetPieces=true` or `kGetOffsets=true`, and when so, this - // value will be updated accordingly after the wordpiece token has been - // appended to the output. - // * encoded_token_value: the encoded value of the word piece token to be - // appended. See EncodeToken() in fast_wordpiece_tokenizer_utils.h. - template - void AppendTokenToOutput(absl::string_view input_word, - int input_word_offset_in_text, - int& cur_offset_in_input_word, - int encoded_token_value, - std::vector* output_pieces, - std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // This method is called when the trie matching loop encounters a word - // boundary (e.g., the end-of-input). This method segments the remaining - // string on the trie path into pieces and appends them to the outputs. If - // that is not possible with the current vocabulary, this method resets the - // outputs and appends unk_token. - // - // Example 1: suppose the vocabulary is {ab, abcd}. If the input word is "ab", - // after matching "ab", we processed all input characters and now meets the - // end-of-input. Note that the string "ab" is stored on the trie path that we - // just traversed along. This function recognizes it as the token "ab" and - // puts the token into the output as expected. 
- // - // Example 2: for the same vocabulary {ab, abcd}, suppose the input word is - // "abc". After the trie matching loop, we matched "abc" and encountered the - // end-of-input. Now the string "abc" is stored on the trie path, which we - // haven't segmented into tokens yet. So this function closes it by trying to - // segment "abc" into tokens. It fails since the remaining string "abc" cannot - // be tokenized into tokens given the vocabulary. In this case, it resets the - // outputs and appends unk_token at the end as expected. - template - void HandleTheRemainingStringOnTriePath( - absl::string_view input_word, int input_word_offset_in_text, - trie_utils::DartsCloneTrieWrapper::TraversalCursor& cur_node, - int& original_num_tokens, int& cur_offset_in_input_word, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // Resets the output and appends unk_token. - // - // We call this method when we find that the input word cannot be tokenized. - // We clear all new tokens recognized so far and replace them with a single - // unk_token. - // - // Args: - // * input_word_offset_in_text: The offset of the current word in the - // input text. - // * input_size: The length of the current input word, in utf-8 bytes. - // * original_num_tokens: The original number of tokens in the output before - // we started the tokenization of the current input word. It is updated - // after this method. - template - void ResetOutputAppendUnknownToken( - int input_word_offset_in_text, int input_size, int& original_num_tokens, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // Try handling the special case when the input word is the suffix indicator - // itself. If so, appends the precomputed result to output_pieces and - // output_ids, and returns true. Otherwise, it does nothing and returns false. 
- template - bool TryHandleTheInputWordBeingSuffixIndicatorItself( - absl::string_view input_word, int input_word_offset_in_text, - const trie_utils::DartsCloneTrieWrapper::TraversalCursor& cur_node, - int& cur_offset_in_input_word, int original_num_tokens, - std::vector* output_pieces, std::vector* output_ids, - std::vector* output_start_offsets, - std::vector* output_end_offsets) const; - - // Returns the position (in bytes) immediately after the end of the word. - int SkipTheRemainingOfWordAndTrailingWhiteSpaces(absl::string_view input, - int& cur_pos) const; - - // Points to the FastWordpieceTokenizer config flatbuffer (not owned). - const FastWordpieceTokenizerConfig* config_ = nullptr; - - // A wrapper to access the trie encoded inside the flatbuffer that `config_` - // points to. - std::unique_ptr trie_ = nullptr; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.cc deleted file mode 100644 index ca3e49435..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER(Name(FastWordpieceTokenizeWithOffsetsOpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - FastWordpieceTokenizeWithOffsetsOpKernel); - -REGISTER_KERNEL_BUILDER(Name(FastWordpieceDetokenizeOpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - FastWordpieceDetokenizeOpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.h index 29542d92d..b2d0ae73a 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel.h @@ -15,25 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h" - -namespace tensorflow { -namespace text { - -class FastWordpieceTokenizeWithOffsetsOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -class FastWordpieceDetokenizeOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h index efc26197a..8d0075b56 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h +++ 
b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h @@ -15,364 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_ -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h" - -namespace tensorflow { -namespace text { - -// See `kDoc` data member for the documentation on this op kernel. -// -// This template class can be instantiated into a kernel for either TF or -// TFLite. See -// https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/kernels/shim -// for more info on how this works. -template -class FastWordpieceTokenizeWithOffsetsOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { kInputValues = 0, kWpModel }; - enum Outputs { - kOutputSubwords = 0, - kOutputIds, - kOutputRowSplits, - kStartValues, - kEndValues - }; - - using Shape = tflite::shim::Shape; - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - FastWordpieceTokenizeWithOffsetsOp() = default; - static constexpr char kOpName[] = "FastWordpieceTokenizeWithOffsets"; - static constexpr char kDoc[] = R"doc( - Tokenizes tokens into sub-word pieces based off of a vocabulary using the fast - linear WordPiece algorithm. - - `wordpiece_tokenize_with_offsets` returns the relative offsets. - - ### Example: - - ```python - >>> tokens = ['don', '\'t', 'treadness'] - >>> wordpiece, ids, row_splits, start, end = ( - ... 
fast_wordpiece_tokenize_with_offsets(tokens, model_buffer)) - >>> RaggedTensor.from_row_splits(wordpiece, row_splits) - [['don', '\'', 't'], ['tread', '##ness']] - >>> RaggedTensor.from_row_splits(ids, row_splits) - [[0, 1, 2], [3, 4]] # Dummy ids. - >>> RaggedTensor.from_row_splits(start, row_splits) - start = [[[0, 3, 4], [0, 5]]] - >>> RaggedTensor.from_row_splits(end, row_splits) - end = [[[3, 4, 5], [5, 10]]] - ``` - - Args: - input_values: 1D Tensor of strings to tokenize with. - wp_model: Buffer tensor for the FastWordpieceTokenizerConfig flatbuffer. - - Returns: - * output_values: 1D tensor containing the wordpieces for all input strings. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_ids: 1D tensor containing the wordpiece ids for all input strings. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_row_splits: 1D int tensor with the row splits that allow us to - build RaggedTensors from output_values, output_ids, start_values, and - end_values. - * start_values: 1D tensor containing the inclusive start byte offset for - each wordpiece in all input strings. Corresponds 1:1 with output_values. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * end_values: 1D tensor containing the exclusive end byte offset for - each wordpiece in all input strings. Corresponds 1:1 with output_values. - A 2D RaggedTensor can be constructed from this and output_row_splits. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Input tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Output tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -std::vector FastWordpieceTokenizeWithOffsetsOp::Inputs() { - return {"input_values: string", "wp_model: uint8"}; -} - -template -std::vector FastWordpieceTokenizeWithOffsetsOp::Outputs() { - return {"output_subwords: string", "output_ids: int64", - "output_row_splits: int64", "start_values: int64", - "end_values: int64"}; -} - -template -absl::Status FastWordpieceTokenizeWithOffsetsOp::Invoke( - InvokeContext* context) { - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto& values_vec = input_values->template As(); - - SH_ASSIGN_OR_RETURN(const auto wp_model, context->GetInput(kWpModel)); - // OK to create on every call because FastWordpieceTokenizer is a - // lightweight, memory-mapped wrapper on `wp_model` tensor, and thus - // Create() is very cheap. - auto fast_wordpiece_tokenizer = - ::tensorflow::text::FastWordpieceTokenizer::Create( - wp_model->template Data().data()); - SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status()); - - // TODO(xysong): Optimize based on which information below is requested. 
- std::vector subwords; - std::vector subword_ids; - std::vector begin_offset; - std::vector end_offset; - std::vector row_splits; - - row_splits.push_back(0); - - // Iterate through all the values and wordpiece tokenize them. - for (int i = 0; i < values_vec.Dim(0); ++i) { - // Tokenize into subwords and record the offset locations. - const int original_num_wordpieces = subwords.size(); - bool error = false; - fast_wordpiece_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids, - &begin_offset, &end_offset, - /*input_word_offset_in_text=*/0, &error); - if (error) { - return absl::InternalError( - "Failed to make any progress in tokenizing the input text."); - } - const int delta_num_wordpieces = subwords.size() - original_num_wordpieces; - - // Record the row splits. - row_splits.push_back(delta_num_wordpieces + row_splits.back()); - } - - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - subwords, kOutputSubwords, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - subword_ids, kOutputIds, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - row_splits, kOutputRowSplits, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - begin_offset, kStartValues, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - end_offset, kEndValues, context)); - - return absl::OkStatus(); -} - -template -absl::Status FastWordpieceTokenizeWithOffsetsOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - SH_ASSIGN_OR_RETURN(const Shape input_values_shape, - c->GetInputShape(kInputValues)); - SH_ASSIGN_OR_RETURN(const auto wp_model_shape, c->GetInputShape(kWpModel)); - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_values_shape.ToString())); - } - if (!wp_model_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - 
absl::StrCat("Shape must be rank 1: ", wp_model_shape.ToString())); - } - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputSubwords, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputIds, rank_1_shape)); - // row splits size - const int num_splits = Shape::AddDims(1, input_values_shape.Dim(0)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, Shape({num_splits}))); - SH_RETURN_IF_ERROR(c->SetOutputShape(kStartValues, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kEndValues, rank_1_shape)); - - return absl::OkStatus(); -} - - -// See `kDoc` data member for the documentation on this op kernel. -// -// This template class can be instantiated into a kernel for either TF or -// TFLite. See -// https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/kernels/shim -// for more info on how this works. -template -class FastWordpieceDetokenizeOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { kInputValues = 0, kInputRowSplits, kWpModel }; - enum Outputs { kOutputWords = 0 }; - - using Shape = tflite::shim::Shape; - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - FastWordpieceDetokenizeOp() = default; - static constexpr char kOpName[] = "TFText>FastWordpieceDetokenize"; - static constexpr char kDoc[] = R"doc( - Detokenizes sub-word ids into sentences. - - ### Example: - - ```python - >>> # Vocab of the model_buffer: ['a', 'ab', '##c', 'abc', '##d']. - >>> wordpiece_ids = [0, 1, 2, 3, 4] - >>> row_splits = [0, 3, 5] - >>> tokens = fast_wordpiece_tokenizer_detokenize(tokens, row_splits, model_buffer) - >>> tokens - ['a abc', 'abcd'] - ``` - - Args: - input_values: 1D Tensor of sub-word ids. - input_row_splits: 1D Tensor of row splits that denotes the boundary of each - sentence in the `input_values`. - wp_model: Buffer tensor for the FastWordpieceTokenizerConfig flatbuffer. 
- - Returns: - * output_values: 1D tensor containing all the sentences. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Input tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Output tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -std::vector FastWordpieceDetokenizeOp::Inputs() { - return {"input_values: int32", "input_row_splits: int64", "wp_model: uint8"}; -} - -template -std::vector FastWordpieceDetokenizeOp::Outputs() { - return {"output_words: string"}; -} - -template -absl::Status FastWordpieceDetokenizeOp::Invoke(InvokeContext* context) { - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto& values_vec = input_values->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_row_splits, - context->GetInput(kInputRowSplits)); - const auto& row_splits_vec = input_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto wp_model, context->GetInput(kWpModel)); - // OK to create on every call because FastWordpieceTokenizer is a - // lightweight, memory-mapped wrapper on `wp_model` tensor, and thus - // Create() is very cheap. - auto fast_wordpiece_tokenizer = - ::tensorflow::text::FastWordpieceTokenizer::Create( - wp_model->template Data().data()); - SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status()); - - std::vector sentences; - - // Iterate through row_splits to split input_values. 
- for (int i = 0; i < row_splits_vec.Dim(0) - 1; ++i) { - auto single_input = - absl::Span(values_vec.Ptr() + row_splits_vec(i), - row_splits_vec(i + 1) - row_splits_vec(i)); - SH_ASSIGN_OR_RETURN(auto sentence, - fast_wordpiece_tokenizer->Detokenize(single_input)); - sentences.push_back(sentence); - } - - const int words_size = sentences.size(); - SH_ASSIGN_OR_RETURN(auto output_words, - context->GetOutput(kOutputWords, Shape({words_size}))); - auto output_words_vec = output_words->template As(); - - for (int i = 0; i < words_size; ++i) { - output_words_vec(i) = sentences[i]; - } - - return absl::OkStatus(); -} - -template -absl::Status FastWordpieceDetokenizeOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - SH_ASSIGN_OR_RETURN(const Shape input_values_shape, - c->GetInputShape(kInputValues)); - SH_ASSIGN_OR_RETURN(const Shape input_row_splits_shape, - c->GetInputShape(kInputRowSplits)); - SH_ASSIGN_OR_RETURN(const auto wp_model_shape, c->GetInputShape(kWpModel)); - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_values_shape.ToString())); - } - if (!input_row_splits_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_row_splits_shape.ToString())); - } - if (!wp_model_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", wp_model_shape.ToString())); - } - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputWords, rank_1_shape)); - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model.fbs 
b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model.fbs deleted file mode 100644 index 3f508f677..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model.fbs +++ /dev/null @@ -1,68 +0,0 @@ -namespace tensorflow.text; - -struct FailureStruct { - // The failure link of node v, denoted as f(v). - failure_link: uint32; - - // The failure pops of node v, denoted as F(v). It is an encoded value of - // (offset, length) that represents a consecutive subarray in - // 'failure_pops_pool' (see FastWordpieceTokenizerConfig). - failure_pops_offset_length: uint32; -} - -table FastWordpieceTokenizerConfig { - // The trie data, in the format of darts_clone trie, as accepted by - // DartsCloneTrieWrapper::Create(). - trie_array: [uint32]; - - // The array of the failure structures. - failure_struct_array: [FailureStruct]; - - // The array holding the failure pops. - failure_pops_pool: [int]; - - // The trie suffix root node id. - trie_suffix_root: uint32; - - // Max size of the input token. If the input length is longer than this, it - // will be mapped to unk_token. - max_bytes_per_token: int; - - // Characters prepended to a wordpiece to indicate that it is a suffix to - // another subword, such as "##". - suffix_indicator: string; - - // The unknown token string. - unk_token: string; - - // The unkown token id. - unk_token_id: int; - - // The precomputed result for the input being the suffix indicator itself. - precomputed_result_for_suffix_indicator: [int]; - - // The node id of every punctuation's failure link. It is only used when - // end_to_end=true. - trie_punct_failure_link_node: uint32; - - // Whether to build end-to-end tokenizer for tokenizing general texts (as - // opposed to splitted single words). When it is true, the input text is first - // split into words on "punctuation"/whitespaces, and each word is further - // tokenized into subwords. 
- // Note that our definition of "punctuation" includes some special Chinese - // characters for compatibility with Bert. More details are available in - // `fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar`. - end_to_end: bool; - - // Whether the tokenizer supports detokenization function. - support_detokenization: bool; - - // WordPiece Vocabulary. Note that we remove suffix indicator from suffix - // tokens for saving space. - vocab_array: [string]; - - // Whether the corresponding token in the vocab_array is a suffix token. - vocab_is_suffix_array: [bool]; -} - -root_type FastWordpieceTokenizerConfig; diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc deleted file mode 100644 index 9467c4d6e..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc +++ /dev/null @@ -1,941 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h" - -#include - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include /* cppitertools */ "imap.hpp" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_builder.h" -#include "tensorflow_text/core/kernels/darts_clone_trie_wrapper.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_generated.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2.h" -#include "tensorflow_text/core/kernels/string_vocab.h" - -namespace tensorflow { -namespace text { -namespace { - -// A Unicode control char that never appears in the input as it is filtered -// during text normalization. It is used to build dummy nodes in the trie. -static constexpr char kInvalidControlChar = 0x11; - -// A wrapper of vocab tokens that will be used to build the trie. -class TrieVocabToken { - public: - TrieVocabToken(absl::string_view token, int token_id, - absl::string_view suffix_indicator) - : token_(std::string(token)), token_id_(token_id) { - if (!suffix_indicator.empty() && token_ != suffix_indicator && - absl::StartsWith(token_, suffix_indicator)) { - is_suffix_token_ = true; - actual_token_start_offset_ = suffix_indicator.size(); - } - // Iterate over the Unicode chars from the token, to initialize - // contains_punctuation_ and actual_token_unicode_len_. 
- int token_len = token.size(); - int cur_pos = actual_token_start_offset_; - UChar32 c; - while (cur_pos < token_len) { - U8_NEXT(token, cur_pos, token_len, c); - if (!contains_punctuation_ && - fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar(c)) { - contains_punctuation_ = true; - } - ++actual_token_unicode_len_; - } - } - - absl::string_view Token() const { return token_; } - - int TokenId() const { return token_id_; } - - bool IsSuffixToken() const { return is_suffix_token_; } - - bool ContainsPunctuation() const { return contains_punctuation_; } - - int TokenUnicodeLengthWithoutSuffixIndicator() const { - return actual_token_unicode_len_; - } - - int TokenLengthWithoutSuffixIndicator() const { - return token_.size() - actual_token_start_offset_; - } - - private: - std::string token_; - - int token_id_ = -1; - - // By design, `is_suffix_token_`=false for the suffix indicator (e.g., "##") - // itself. - bool is_suffix_token_ = false; - - // The starting offset of the token string in `token_` without the suffix - // indicator. By design, `actual_token_start_offset_`=0 for the suffix - // indicator (e.g., "##") itself. - int actual_token_start_offset_ = 0; - - // Length of the actual token string in Unicode character. - int actual_token_unicode_len_ = 0; - - // True when the actual token string contains punctuation, e.g. "test.x", - // "##.", ".test", "...", "!", etc. - bool contains_punctuation_ = false; -}; - -// The failure struct to store failure links and failure pops. -struct FailureStruct { - // The failure link, denoted as f(v), of each node v. - // - // Null node is represented by fast_wordpiece_tokenizer_utils::kNullNode. - uint32_t failure_link = fast_wordpiece_tokenizer_utils::kNullNode; - - // The failure pop list, denoted as F(v), of a node v. - // - // It is stored as a pair of offset and length that represents a continuous - // vector in `failure_pops_pool_`. 
This pair is encoded using - // EncodeFailurePopList() in fast_wordpiece_tokenizer_utils.h. - uint32_t failure_pops_offset_length = - fast_wordpiece_tokenizer_utils::kNullFailurePopsList; -}; - -// Builds the FastWordpieceTokenizer model. -class FastWordpieceBuilder { - public: - // When no_pretokenization is false, we split the input string by punctuation - // chars (in addition to whitespaces) and then tokenize it to wordpieces. - absl::Status BuildModel(const std::vector& vocab, - int max_bytes_per_token, - absl::string_view suffix_indicator, - absl::string_view unk_token, - bool no_pretokenization, - bool support_detokenization); - - absl::StatusOr ExportToFlatBuffer() const; - - private: - absl::StatusOr> PrepareVocabTokensToBuildTrie(); - - absl::Status ConstructTrie( - const std::vector& tokens_to_build_trie); - - absl::Status BuildFailureStructure( - const std::vector& tokens_to_build_trie); - - // Builds the set of outgoing edge labels for each trie node and returns a - // mapping (node_id -> set). Used in BuildFailureStructure(). - absl::StatusOr>> - BuildOutgoingEdgeLabelsForTrie( - const std::vector& tokens_to_build_trie); - - // Builds the set of outgoing edge labels for nodes along the trie path of - // `vocab_token`. Used in BuildOutgoingEdgeLabelsForTrie(). - absl::Status BuildOutgoingEdgeLabelsAlongVocabToken( - const TrieVocabToken& vocab_token, - std::vector>& node_outgoing_edge_labels); - - // Assigns failure link f(cur_node) to `failure_link` and populates failure - // pops F(cur_node) (based on `one_step_pops` and - // `parent_failure_pops_offset_length`). - absl::Status AssignFailureLinkAndPops(uint32_t cur_node, - uint32_t failure_link, - const std::vector& one_step_pops, - int parent_failure_pops_offset_length); - - // If `failure_pops_offset_length` encodes a valid failure pop list, appends - // the failure pop list to the end of `out_failure_pops`. Otherwise, does - // nothing. 
- void GetFailurePopsAndAppendToOut(uint32_t failure_pops_offset_length, - std::vector& out_failure_pops); - - absl::Status PrecomputeResultForSuffixIndicator(); - - inline void BreakTrieLinkFromParentToChild(uint32_t child_node_id) { - // In trie, the least significant 8 bits encode the label of the trie link - // from the parent to the node itself. - // - // Reference: - // https://github.com/s-yata/darts-clone/blob/e40ce4627526985a7767444b6ed6893ab6ff8983/include/darts.h#L65-L70. - // - // For example, if there is a trie link `u` -> `v` with label (say) 'a' - // (ASCII 97 or 0x61), then the least significant 8 bits of node `v` will be - // 0x61. By erasing its least significant 8 bits to 0, it effectively - // prevents the node from being reachable from its parent, i.e. breaking the - // trie link from the parent to the node itself. - trie_array_[child_node_id] &= 0xFFFFFF00; - } - - inline void EraseValueOfNode(uint32_t node_id) { - // In trie, the 9th least significant bit of a node's value marks whether - // the node has a leaf node (i.e., having a value stored on the node). - // - // Reference: - // https://github.com/s-yata/darts-clone/blob/e40ce4627526985a7767444b6ed6893ab6ff8983/include/darts.h#L54-L58 - // - // By setting the 9th least significant bit to 0, it effectively erases any - // value (i.e., token id in our case) associated with the node. - trie_array_[node_id] &= 0xFFFFFEFF; - } - - std::unique_ptr vocab_; - - int max_bytes_per_token_ = -1; - - std::string suffix_indicator_; - - std::string unk_token_; - - int unk_token_id_ = -1; - - // A wrapper to access the trie encoded by `trie_array_`. - absl::optional trie_; - - // The actual data of the trie. - std::vector trie_array_; - - // The "suffix_root" node on the trie whose trie path (from the root to the - // node) is the suffix indicator string. - uint32_t trie_suffix_root_ = fast_wordpiece_tokenizer_utils::kNullNode; - - // The dummy node to serve as the failure link of punctuation nodes. 
- uint32_t trie_punct_failure_link_node_ = - fast_wordpiece_tokenizer_utils::kNullNode; - - // Whether to build the end-to-end tokenizer that tokenizes general texts. - // When set to false, it splits the input on punctuation/whitespace and treat - // each punctuation as an independent word. - bool no_pretokenization_; - - // Whether the tokenizer supports the detokenization function. - bool support_detokenization_; - - std::vector failure_struct_array_; - - // Each element in the failure pops pool is an encoded vocab token. - // See EncodeToken() in fast_wordpiece_tokenizer_utils.h. - std::vector failure_pops_pool_; - - // The precomputed result for the suffix indicator. Each element in the - // failure pops pool is an encoded vocab token. See EncodeToken() in - // fast_wordpiece_tokenizer_utils.h. - std::vector precomputed_result_for_suffix_indicator_; - - // The mapping from node id to whether the corresponding token is a - // punctuation char. - absl::flat_hash_map node_id_is_punc_map_; -}; - -absl::Status FastWordpieceBuilder::BuildModel( - const std::vector& vocab, int max_bytes_per_token, - absl::string_view suffix_indicator, absl::string_view unk_token, - bool no_pretokenization, bool support_detokenization) { - unk_token_ = std::string(unk_token); - suffix_indicator_ = std::string(suffix_indicator); - max_bytes_per_token_ = max_bytes_per_token; - no_pretokenization_ = no_pretokenization; - support_detokenization_ = support_detokenization; - - vocab_ = std::make_unique(vocab); - if (vocab_->Size() != vocab.size()) { - return absl::FailedPreconditionError( - "Tokens in the vocabulary must be unique."); - } - - // Determine `unk_token_id_`. - const absl::optional unk_token_id = vocab_->LookupId(unk_token_); - if (!unk_token_id.has_value()) { - return absl::FailedPreconditionError("Cannot find unk_token in the vocab!"); - } - unk_token_id_ = *unk_token_id; - - // Construct the trie and the failure structure. 
- SH_ASSIGN_OR_RETURN(auto tokens_to_build_trie, - PrepareVocabTokensToBuildTrie()); - SH_RETURN_IF_ERROR(ConstructTrie(tokens_to_build_trie)); - SH_RETURN_IF_ERROR(BuildFailureStructure(tokens_to_build_trie)); - - // Precompute the result when the input is the suffix indicator string itself. - SH_RETURN_IF_ERROR(PrecomputeResultForSuffixIndicator()); - - return absl::OkStatus(); -} - -absl::StatusOr> -FastWordpieceBuilder::PrepareVocabTokensToBuildTrie() { - // To simplify the inference (fewer corner cases), - // * We ensure that `trie_suffix_root_` is always available on the trie. - // * We ensure that `trie_suffix_root_` does not have data (i.e., the suffix - // indicator string is not in the set of the keys of the trie). - // * We don't actually add the end-of-input symbol "$" but use an alternative - // logic. See FastWordpieceTokenizer::HandleTheRemainingStringOnTriePath(). - - if (vocab_->Size() > fast_wordpiece_tokenizer_utils::kMaxSupportedVocabSize) { - return absl::FailedPreconditionError( - absl::StrCat("Vocab size exceeds the max supported (", - fast_wordpiece_tokenizer_utils::kMaxSupportedVocabSize, - "). Found vocab size: ", vocab_->Size(), ".")); - } - - // Collect a subset of tokens (and variations) to build the trie. - std::vector tokens_to_build_trie; - tokens_to_build_trie.reserve(vocab_->Size()); - for (int token_id = 0; token_id < vocab_->Size(); ++token_id) { - const absl::optional word = vocab_->LookupWord(token_id); - if (!word.has_value()) { - return absl::FailedPreconditionError( - "Impossible. `token_id` is definitely within the range of vocab " - "token ids; hence LookupWord() should always succeed."); - } - if (word->empty()) { - // It does not make sense to add the empty string "" to the vocabulary. In - // addition, darts_clone does not allow an empty Trie key. - // - // We allow this only for compatibility with the original Wordpiece - // algorithm. 
- LOG(WARNING) - << "The empty string is found in the vocabulary, which takes place " - "in the token id space but will never be used in the result. " - "Consider cleaning it from the vocabulary."; - continue; - } - if (*word == suffix_indicator_) { - // In real-life cases, no need to add the suffix indicator string (e.g., - // "##") to the vocabulary. - // - // We allow this only for compatibility with the original Wordpiece - // algorithm. - LOG(WARNING) - << "The empty suffix token is found in the vocabulary, which takes " - "place in token id space but will (almost) never be used in the " - "result. Consider cleaning it from the vocabulary."; - - // The token id of the suffix indicator is used only when the input is - // the suffix indicator itself. That case is handled elsewhere, in - // PrecomputeResultForSuffixIndicator(). - // - // Therefore, we don't insert the suffix indicator string as a key into - // the trie. As a result, `trie_suffix_root_` node will never have data. - - continue; - } - TrieVocabToken vocab_token(*word, token_id, suffix_indicator_); - if (vocab_token.TokenLengthWithoutSuffixIndicator() > - fast_wordpiece_tokenizer_utils::kMaxVocabTokenLengthInUTF8Bytes) { - return absl::FailedPreconditionError(absl::StrCat( - "Vocab token utf8 length (excluding suffix indicator) exceeds the " - "max supported (", - fast_wordpiece_tokenizer_utils::kMaxVocabTokenLengthInUTF8Bytes, - "). The vocab token is: ", *word, - " with utf8 length (excluding suffix indicator): ", - vocab_token.TokenLengthWithoutSuffixIndicator(), ".")); - } - // Skip word that contains punctuation but is not a punctuation itself. - // , , ##. are skipped in this step. 
- if (!no_pretokenization_ && vocab_token.ContainsPunctuation() && - (vocab_token.TokenUnicodeLengthWithoutSuffixIndicator() > 1 || - vocab_token.IsSuffixToken())) { - continue; - } - tokens_to_build_trie.emplace_back(vocab_token); - } - - if (tokens_to_build_trie.empty()) { - return absl::FailedPreconditionError( - "No valid vocab tokens were found to build the trie."); - } - if (!suffix_indicator_.empty()) { - const bool suffix_token_exists = std::any_of( - tokens_to_build_trie.begin(), tokens_to_build_trie.end(), - [](const TrieVocabToken& token) { return token.IsSuffixToken(); }); - if (!suffix_token_exists) { - // No suffix tokens in the vocab. That would lead to no trie node for - // the suffix indicator, which creates corner cases in the inference. - // To prevent that, we add a dummy suffix token, e.g., "##" + - // kInvalidControlChar (if the suffix indicator is "##"), which is never - // matched during inference. - tokens_to_build_trie.emplace_back(TrieVocabToken( - absl::StrCat(suffix_indicator_, std::string(1, kInvalidControlChar)), - unk_token_id_, suffix_indicator_)); - } - } - - if (!no_pretokenization_) { - // Special treatment for all Unicode punctuation chars that are not already - // in the trie. - // The maximum codepoint in Unicode is 0x0010FFFF. - for (UChar32 cp = 1; cp <= 0x0010FFFF; ++cp) { - if (!U_IS_UNICODE_CHAR(cp) || - !fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar(cp)) { - continue; - } - // Get the UTF8 encoding of the codepoint cp. - char buf[4]; - int len = 0; - U8_APPEND_UNSAFE(buf, len, cp); - absl::string_view buf_view(buf, len); - // Set the token id of punctuation chars that don't exist in the vocab as - // unk_token_id_. - if (!vocab_->LookupId(buf_view)) { - TrieVocabToken vocab_token(buf_view, unk_token_id_, suffix_indicator_); - tokens_to_build_trie.emplace_back(vocab_token); - } - } - // Insert a dummy node to serve as the failure link targets for punctuation - // nodes. 
- tokens_to_build_trie.emplace_back(TrieVocabToken( - std::string(1, kInvalidControlChar), unk_token_id_, suffix_indicator_)); - } - return tokens_to_build_trie; -} - -absl::Status FastWordpieceBuilder::ConstructTrie( - const std::vector& tokens_to_build_trie) { - std::vector keys; - std::vector values; - for (const TrieVocabToken& vocab_token : tokens_to_build_trie) { - keys.emplace_back(vocab_token.Token()); - SH_ASSIGN_OR_RETURN(int encoded_value, - fast_wordpiece_tokenizer_utils::EncodeToken( - vocab_token.TokenId(), - vocab_token.TokenLengthWithoutSuffixIndicator(), - vocab_token.IsSuffixToken())); - values.push_back(encoded_value); - } - SH_ASSIGN_OR_RETURN(trie_array_, - trie_utils::BuildDartsCloneTrie(keys, values)); - SH_ASSIGN_OR_RETURN( - trie_utils::DartsCloneTrieWrapper trie, - trie_utils::DartsCloneTrieWrapper::Create(trie_array_.data())); - trie_.emplace(std::move(trie)); - - if (trie_array_.size() > - fast_wordpiece_tokenizer_utils::kMaxSupportedTrieSize) { - return absl::FailedPreconditionError(absl::StrCat( - "Not supported since the constructed Darts trie size (", - trie_array_.size(), ") is greater than the maximum supported size (", - fast_wordpiece_tokenizer_utils::kMaxSupportedTrieSize, ").")); - } - - // Locate the trie suffix root. - auto node = trie_->CreateTraversalCursorPointToRoot(); - if (!trie_->TryTraverseSeveralSteps(node, suffix_indicator_)) { - return absl::FailedPreconditionError( - "Cannot locate trie_suffix_root_. This should never happen."); - } - trie_suffix_root_ = node.node_id; - - if (!no_pretokenization_) { - // Locate the dummy node for the failure link for punctuation nodes. - node = trie_->CreateTraversalCursorPointToRoot(); - if (!trie_->TryTraverseSeveralSteps(node, - std::string(1, kInvalidControlChar))) { - return absl::FailedPreconditionError( - "Cannot locate the dummy node for the failure link for punctuation " - "nodes. 
This should never happen."); - } - trie_punct_failure_link_node_ = node.node_id; - - // We make `trie_punct_failure_link_node_` a standalone dummy node. - EraseValueOfNode(trie_punct_failure_link_node_); - BreakTrieLinkFromParentToChild(trie_punct_failure_link_node_); - } - return absl::OkStatus(); -} - -absl::Status FastWordpieceBuilder::BuildOutgoingEdgeLabelsAlongVocabToken( - const TrieVocabToken& vocab_token, - std::vector>& node_outgoing_edge_labels) { - const absl::string_view token = vocab_token.Token(); - trie_utils::DartsCloneTrieWrapper::TraversalCursor cur_node; - int char_pos = 0; - trie_->SetTraversalCursor(cur_node, trie_->kRootNodeId); - while (char_pos < token.size()) { - const char edge_label = token[char_pos]; - node_outgoing_edge_labels[cur_node.node_id].insert(edge_label); - if (!trie_->TryTraverseOneStep(cur_node, edge_label)) { - // Should never happen, since we built trie using all of `vocab_token`. - return absl::FailedPreconditionError(absl::StrCat( - "Cannot traverse from parent id ", cur_node.node_id, - " to child following the edge with label value of ", - static_cast(edge_label), - " when processing a vocabulary token with token ID ", - vocab_token.TokenId(), " (0-based). This error happened at ", - "position ", char_pos, " (0-based) of the token. Before that, ", - "the prefix \"", token.substr(0, char_pos), - "\" of the token had been processed. This should never happen. ", - "This probably indicates that there are some unicode ", - "issues (e.g., byte '\\x0' in the middle) for the above ", - "mentioned token in the vocabulary file. All bytes of this ", - "questionable token (ID ", vocab_token.TokenId(), ") are: [", - absl::StrJoin( - iter::imap([](auto ch) { return static_cast(ch); }, - vocab_token.Token()), - ", "), - "].")); - } - ++char_pos; - } - // Record whether the current node represents a punctuation char in the map. 
- node_id_is_punc_map_[cur_node.node_id] = - !vocab_token.IsSuffixToken() && vocab_token.ContainsPunctuation() && - vocab_token.TokenUnicodeLengthWithoutSuffixIndicator() == 1; - return absl::OkStatus(); -} - -absl::StatusOr>> -FastWordpieceBuilder::BuildOutgoingEdgeLabelsForTrie( - const std::vector& tokens_to_build_trie) { - std::vector> node_outgoing_edge_labels( - trie_array_.size()); - const std::string dummy_token_for_trie_punct_failure_link_node = - std::string(1, kInvalidControlChar); - for (const TrieVocabToken& vocab_token : tokens_to_build_trie) { - if (vocab_token.Token() == dummy_token_for_trie_punct_failure_link_node) - continue; - SH_RETURN_IF_ERROR(BuildOutgoingEdgeLabelsAlongVocabToken( - vocab_token, node_outgoing_edge_labels)); - } - return node_outgoing_edge_labels; -} - -// Computes failure links and failure pops using BFS traversal. -absl::Status FastWordpieceBuilder::BuildFailureStructure( - const std::vector& tokens_to_build_trie) { - // Build the set of outgoing edge labels for each trie node (node_id -> - // set). This is needed by BFS because darts-clone does not provide an - // API to enumerate the outgoing links for a node. - SH_ASSIGN_OR_RETURN( - std::vector> node_outgoing_edge_labels, - BuildOutgoingEdgeLabelsForTrie(tokens_to_build_trie)); - - failure_struct_array_.resize(trie_array_.size()); - // Initialize the BFS queue. - std::queue bfs_queue({trie_->kRootNodeId}); - if (trie_suffix_root_ != trie_->kRootNodeId) { - // When `suffix_indicator_` is empty, `trie_suffix_root_` will collapse - // with root. In this case, we don't visit it twice. - // - // In addition, we have ensured that `trie_suffix_root_` will never be null. - // See PrepareVocabTokensToBuildTrie(). - bfs_queue.push(trie_suffix_root_); - } - - // The BFS loop. - while (!bfs_queue.empty()) { - uint32_t parent_id = bfs_queue.front(); - bfs_queue.pop(); - - // Explore the children of the parent node. 
- // - // Fix the iteration order of the outgoing edges to ensure that the model is - // always built in the same way (i.e., visiting nodes in the same order). - std::vector outgoing_labels_sorted( - node_outgoing_edge_labels[parent_id].begin(), - node_outgoing_edge_labels[parent_id].end()); - std::sort(outgoing_labels_sorted.begin(), outgoing_labels_sorted.end()); - for (const char edge_label : outgoing_labels_sorted) { - auto child_node = trie_->CreateTraversalCursor(parent_id); - if (!trie_->TryTraverseOneStep(child_node, edge_label)) { - // Should never happen, due to how we built `node_outgoing_edge_labels`; - // see BuildOutgoingEdgeLabelsAlongVocabToken(). - return absl::FailedPreconditionError(absl::StrCat( - "Failed to traverse to child following edge ", - absl::string_view(&edge_label, 1), " at parent ", parent_id, ".")); - } - if (child_node.node_id == trie_suffix_root_) { - // Avoid visiting `trie_suffix_root_` twice. - continue; - } - - // For the child node v, compute failure link f(v) and failure pops F(v). - // - // In the comments below, str(v) is the string on the path from the trie - // root to the node v, and V is the vocabulary used to build the trie. - - int child_data_value = -1; - if (trie_->TryGetData(child_node, child_data_value)) { - uint32_t failure_link = trie_suffix_root_; - // Check whether the current node represents a punctuation char. - // Since the current node has data and thus corresponds to some token, - // it must be in the map `node_id_is_punc_map_` - if (!node_id_is_punc_map_.contains(child_node.node_id)) { - return absl::FailedPreconditionError( - "Failed to find if an end node in the trie is a punctuation char " - "in node_id_is_punc_map_. 
It should never happen."); - } - if (!no_pretokenization_ && - node_id_is_punc_map_.at(child_node.node_id)) { - // For end-to-end tokenizer, we set the failure link node of every - // punctuation char as a special node trie_punct_failure_link_node_ - // which is a dummy node (no parent, no descendants, failure link is - // null). Hence, by detecting the landing node, we know we just - // matched a punctuation char. We then split it as a single word. - failure_link = trie_punct_failure_link_node_; - } - // Case 1 (easy): str(v) is in V. Assume that during tokenization of a - // word, we reached node v, but can't continue further, because the - // current char from the input word does not match any of the edges - // outgoing from v. In that case, str(v) is already the max match, so - // it's the only wordpiece we add to the list of wordpieces we committed - // to. Hence, F(v) = [str(v)]. The next wordpiece from the current word - // is a suffix, so we move to node f(v) = trie_suffix_root_, which - // represents the suffix indicator (e.g., "##"), from where we continue - // the match process. In summary, we have: - // * f(v) = trie_suffix_root_. - // * F(v) = [str(v)]. - SH_RETURN_IF_ERROR(AssignFailureLinkAndPops( - /*cur_node=*/child_node.node_id, /*failure_link=*/failure_link, - /*one_step_pops=*/{child_data_value}, - /*parent_failure_pops_offset_length=*/ - fast_wordpiece_tokenizer_utils::kNullFailurePopsList)); - bfs_queue.push(child_node.node_id); - continue; - } - - // Case 2 (complex): str(v) is not in V. - // - // Consider the same scenario as in Case 1, where we can't continue - // further from v, but now, str(v) is not a valid wordpiece. Instead, - // we need to consider the wordpieces that the MaxMatch algorithm would - // generate for the beginning of str(v) (these wordpieces are stored in - // F(v)). f(v) (the state we transit to) should correspond to the trie - // node for the remaining suffix of str(v). 
- // - // We could compute F(v) and f(v) by running the original WordPiece - // algorithm. Instead, we do it even faster, by using F(u) and f(u) (the - // similar info for the parent node u). Intuitively F(v) consists of (1) - // the tokens from F(u) and (2) the possible tokens that the MaxMatch - // algorithm would generate for str(f(u)).c, where str(f(u)) is the suffix - // of str(u) not covered by the concatenation of the tokens from F(u), "." - // means concatenation, and c is the edge label character from u to v. - // - // - // Let u be the parent node, and c be the edge label from u to v. To - // compute f(v) and F(v), the loop below uses a node variable z (called - // `itr_node`) and a list G (called `one_steps_pops`). Initially, z is set - // to be f(u), and G is empty. - // 1. If z is null, f(v) will be null, too (see Note 2 below for what - // this means). We're done. - // 2. Check if there is a trie edge out of node z, for label c, leading - // to node goto(z, c). If so, set f(v) = goto(z,c) and F(v) = F(u) + G. - // We're done and break. - // 3. Otherwise, collect the pop tokens (by G = G + F(z)) and - // follows the failure link (by z = f(z)). - // 4. Goes to Step 1 and continue the loop. - // - // Note 1: processing node v depends on the info for nodes z that are - // closer to the root than v. Due to our use of the BFS traversal, that - // info is guaranteed to exist when we examine node v. - // - // Note 2: f(v) is null means that during the tokenization process of some - // input word, if the trie matching cannot continue at node v, there are - // no failure links that we can follow, and (it can be proved that in such - // a case) the input word can't be tokenized with the current vocab. 
- // - // For formal discussions and proofs, please refer to the academic paper - // https://arxiv.org/abs/2012.15524 - const FailureStruct& parent_fs = failure_struct_array_[parent_id]; - if (parent_fs.failure_link != fast_wordpiece_tokenizer_utils::kNullNode) { - std::vector one_step_pops; - auto itr_node = trie_->CreateTraversalCursor(parent_fs.failure_link); - while (true) { - if (trie_->TryTraverseOneStep(itr_node, edge_label)) { - // Set the failure link and failure pops for `child_node`. - SH_RETURN_IF_ERROR(AssignFailureLinkAndPops( - /*cur_node=*/child_node.node_id, - /*failure_link=*/itr_node.node_id, one_step_pops, - parent_fs.failure_pops_offset_length)); - break; - } - const FailureStruct& itr_node_fs = - failure_struct_array_[itr_node.node_id]; - if (itr_node_fs.failure_link == - fast_wordpiece_tokenizer_utils::kNullNode) { - // Cannot follow anymore: failure link of `child_node` will be null. - break; - } - // Append the failure pops of `itr_node` to `one_step_pops`. - GetFailurePopsAndAppendToOut(itr_node_fs.failure_pops_offset_length, - one_step_pops); - // Follow the failure link. - trie_->SetTraversalCursor(itr_node, itr_node_fs.failure_link); - } - } - - bfs_queue.push(child_node.node_id); - } - } - - if (!no_pretokenization_ && !suffix_indicator_.empty()) { - // Rewire trie links along suffix_indicator_. - // If the suffix indicator contains a punctuation char, let `u`--(`c`)-->`v` - // be the first trie edge along the suffix indicator such that the edge - // label (i.e. `c`) is a punctuation char. Note that `u`, `v` are trie - // nodes. `c` is the edge label. We make the following change: - // - // Case 1: if `u` is the root, we remove the trie edge from `v` to its child - // along the suffix indicator. - // Case 2: if `u` is not the root, we remove the trie edge from `u` to `v`. - // - // Example 1: if suffix_indicator_ is "##" (as in BERT), we remove the trie - // link from "#" to "##". 
The goal here is to make sure we match the - // punctuation character "#" as a token by itself, without matching "##" - // (as we split by punctuation, "##" is not a valid token). - // Example 2: if suffix_indicator is "foo#", we remove the trie link from - // "foo" to "foo#". - int cur_pos = 0; - int next_pos = 0; - bool prev_node_id_is_root = false; - auto node = trie_->CreateTraversalCursorPointToRoot(); - UChar32 c; - int suffix_indicator_length = suffix_indicator_.size(); - while (cur_pos < suffix_indicator_length) { - next_pos = cur_pos; - U8_NEXT(suffix_indicator_, next_pos, suffix_indicator_length, c); - prev_node_id_is_root = (node.node_id == trie_->kRootNodeId); - absl::string_view cur_unicode_char(suffix_indicator_.data() + cur_pos, - next_pos - cur_pos); - if (!trie_->TryTraverseSeveralSteps(node, cur_unicode_char)) { - return absl::FailedPreconditionError( - "Cannot locate a character in suffix_indicator_. It should never " - "happen."); - } - if (fast_wordpiece_tokenizer_utils::IsPunctuationOrChineseChar(c)) { - // If the previous node is a root node, read the next char to break the - // link from the current punctuation char to its next child node. - if (prev_node_id_is_root) { - cur_pos = next_pos; - U8_FWD_1(suffix_indicator_, next_pos, suffix_indicator_length); - const absl::string_view next_unicode_char( - suffix_indicator_.data() + cur_pos, next_pos - cur_pos); - auto child_node = node; - if (!trie_->TryTraverseSeveralSteps(child_node, next_unicode_char)) { - return absl::FailedPreconditionError( - "Cannot locate a character in suffix_indicator_. 
It should " - "never happen."); - } - BreakTrieLinkFromParentToChild(child_node.node_id); - } else { - BreakTrieLinkFromParentToChild(node.node_id); - } - break; - } - cur_pos = next_pos; - } - } - return absl::OkStatus(); -} - -absl::Status FastWordpieceBuilder::AssignFailureLinkAndPops( - uint32_t cur_node, uint32_t failure_link, - const std::vector& one_step_pops, - int parent_failure_pops_offset_length) { - if (failure_link == fast_wordpiece_tokenizer_utils::kNullNode) { - return absl::OkStatus(); - } - FailureStruct& cur_node_fs = failure_struct_array_[cur_node]; - cur_node_fs.failure_link = failure_link; - - // Let v be `cur_node` and u be the parent node. - if (one_step_pops.empty()) { - // Case 1: F(v) = F(u). So we just share the same vector. - cur_node_fs.failure_pops_offset_length = parent_failure_pops_offset_length; - } else { - // Case 2: F(v) = F(u) + `one_step_pops`. We need to create a new vector and - // append to `failure_pops_pool_`. - const int failure_pops_offset = failure_pops_pool_.size(); - if (failure_pops_offset > - fast_wordpiece_tokenizer_utils::kMaxSupportedFailurePoolOffset) { - return absl::FailedPreconditionError(absl::StrCat( - "Failure pops list offset is ", failure_pops_offset, - ", which exceeds maximum supported offset ", - fast_wordpiece_tokenizer_utils::kMaxSupportedFailurePoolOffset, - ". The vocabulary seems to be too large to be supported.")); - } - // First copy F(u). - GetFailurePopsAndAppendToOut(parent_failure_pops_offset_length, - failure_pops_pool_); - // Then append `one_step_pops`. 
- failure_pops_pool_.insert(failure_pops_pool_.end(), one_step_pops.begin(), - one_step_pops.end()); - const int failure_pops_length = - failure_pops_pool_.size() - failure_pops_offset; - if (failure_pops_length > - fast_wordpiece_tokenizer_utils::kMaxFailurePopsListSize) { - // This should not happen, because `kBitsToEncodeFailurePopsListSize` is - // set to be less than or equal to `kBitsToEncodeVocabTokenLength` (see - // fast_wordpiece_tokenizer_utils.h). - return absl::FailedPreconditionError(absl::StrCat( - "Failure pops list size is ", failure_pops_length, - ", which exceeds maximum supported size ", - fast_wordpiece_tokenizer_utils::kMaxFailurePopsListSize, ".")); - } - - cur_node_fs.failure_pops_offset_length = - fast_wordpiece_tokenizer_utils::EncodeFailurePopList( - failure_pops_offset, failure_pops_length); - } - return absl::OkStatus(); -} - -void FastWordpieceBuilder::GetFailurePopsAndAppendToOut( - uint32_t failure_pops_offset_length, std::vector& out_failure_pops) { - if (failure_pops_offset_length == - fast_wordpiece_tokenizer_utils::kNullFailurePopsList) { - return; - } - int failure_pops_offset, failure_pops_length; - fast_wordpiece_tokenizer_utils::GetFailurePopsOffsetAndLength( - failure_pops_offset_length, failure_pops_offset, failure_pops_length); - out_failure_pops.insert( - out_failure_pops.end(), failure_pops_pool_.begin() + failure_pops_offset, - failure_pops_pool_.begin() + failure_pops_offset + failure_pops_length); -} - -absl::Status FastWordpieceBuilder::PrecomputeResultForSuffixIndicator() { - std::vector subwords; - std::vector begin_offset; - std::vector end_offset; - int num_word_pieces; - // Use the original WordPiece implementation. 
- LookupStatus status = WordpieceTokenize( - suffix_indicator_, max_bytes_per_token_, /*max_chars_per_subtoken=*/-1, - suffix_indicator_, /*use_unknown_token=*/true, unk_token_, - /*split_unknown_characters=*/false, vocab_.get(), &subwords, - &begin_offset, &end_offset, &num_word_pieces); - precomputed_result_for_suffix_indicator_.reserve(subwords.size()); - if (!status.success) { - return absl::FailedPreconditionError(status.error_msg); - } - for (int i = 0; i < subwords.size(); ++i) { - const absl::optional subword_id = vocab_->LookupId(subwords[i]); - if (!subword_id.has_value()) { - return absl::FailedPreconditionError( - "Impossible because `subwords[i]` must be in the vocabulary!"); - } - TrieVocabToken token(subwords[i], *subword_id, suffix_indicator_); - SH_ASSIGN_OR_RETURN( - int encoded_value, - fast_wordpiece_tokenizer_utils::EncodeToken( - token.TokenId(), token.TokenLengthWithoutSuffixIndicator(), - token.IsSuffixToken())); - precomputed_result_for_suffix_indicator_.push_back(encoded_value); - } - return absl::OkStatus(); -} - -absl::StatusOr FastWordpieceBuilder::ExportToFlatBuffer() const { - flatbuffers::FlatBufferBuilder builder; - - const auto trie_array = builder.CreateVector(trie_array_); - std::vector failure_struct_fbs_vector; - failure_struct_fbs_vector.reserve(failure_struct_array_.size()); - for (const auto& item : failure_struct_array_) { - failure_struct_fbs_vector.emplace_back(item.failure_link, - item.failure_pops_offset_length); - } - const auto failure_structure_array = - builder.CreateVectorOfStructs(failure_struct_fbs_vector); - const auto failure_pops_pool = builder.CreateVector(failure_pops_pool_); - const auto precomputed_result_for_suffix_indicator = - builder.CreateVector(precomputed_result_for_suffix_indicator_); - const auto suffix_indicator = builder.CreateString(suffix_indicator_); - const auto unk_token = builder.CreateString(unk_token_); - - std::vector> vocab_fbs_vector; - std::vector vocab_is_suffix_fbs_vector; - - if 
(support_detokenization_) { - vocab_fbs_vector.reserve(vocab_->Size()); - for (int i = 0; i < vocab_->Size(); ++i) { - const absl::optional word = vocab_->LookupWord(i); - if (!word.has_value()) { - return absl::FailedPreconditionError( - "Impossible. `token_id` is definitely within the range of vocab " - "token ids; hence LookupWord() should always succeed."); - } - absl::string_view token = word.value(); - bool is_suffix_token = false; - if (!suffix_indicator_.empty() && token != suffix_indicator_ && - absl::StartsWith(token, suffix_indicator_)) { - is_suffix_token = true; - // For suffix tokens, we remove the suffix indicator to save spac and - // for ease of use in detokenization (where the suffix indicator will be - // stripped anyway). - token = token.substr(suffix_indicator_.size()); - } - vocab_fbs_vector.emplace_back(builder.CreateString(token)); - vocab_is_suffix_fbs_vector.emplace_back(is_suffix_token); - } - } - - auto vocab_array = builder.CreateVector(vocab_fbs_vector); - auto vocab_is_suffix_array = builder.CreateVector(vocab_is_suffix_fbs_vector); - - FastWordpieceTokenizerConfigBuilder wtcb(builder); - wtcb.add_trie_array(trie_array); - wtcb.add_failure_struct_array(failure_structure_array); - wtcb.add_failure_pops_pool(failure_pops_pool); - wtcb.add_trie_suffix_root(trie_suffix_root_); - wtcb.add_trie_punct_failure_link_node(trie_punct_failure_link_node_); - - wtcb.add_max_bytes_per_token(max_bytes_per_token_); - wtcb.add_suffix_indicator(suffix_indicator); - wtcb.add_unk_token(unk_token); - wtcb.add_unk_token_id(unk_token_id_); - wtcb.add_precomputed_result_for_suffix_indicator( - precomputed_result_for_suffix_indicator); - wtcb.add_end_to_end(!no_pretokenization_); - wtcb.add_support_detokenization(support_detokenization_); - wtcb.add_vocab_array(vocab_array); - wtcb.add_vocab_is_suffix_array(vocab_is_suffix_array); - FinishFastWordpieceTokenizerConfigBuffer(builder, wtcb.Finish()); - return 
std::string(reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()); -} -} // namespace - -absl::StatusOr BuildModelAndExportToFlatBuffer( - const std::vector& vocab, int max_bytes_per_token, - absl::string_view suffix_indicator, absl::string_view unk_token, - bool no_pretokenization, bool support_detokenization) { - FastWordpieceBuilder builder; - SH_RETURN_IF_ERROR(builder.BuildModel(vocab, max_bytes_per_token, - suffix_indicator, unk_token, - no_pretokenization, - support_detokenization)); - SH_ASSIGN_OR_RETURN(std::string flatbuffer, builder.ExportToFlatBuffer()); - return flatbuffer; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h index 080d4ab76..74fa28daa 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h @@ -15,39 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_BUILDER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_BUILDER_H_ -#include -#include +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_model_builder.h" -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { - -// Builds a FastWordpieceTokenizer model in flatbuffer format. -// -// Args: -// * vocab: The WordPiece vocabulary. -// * max_bytes_per_token: The max size of the input token. If the input -// length is longer than this, it will be mapped to unk_token. -// * suffix_indicator: Characters prepended to a wordpiece to indicate that -// it is a suffix to another subword, such as "##". -// * unk_token: The unknown token string. -// * no_pretokenization: Whether to pretokenize on punctuation & whitespace. 
-// Set to `false` when the model is used for general text end-to-end -// tokenization, which combines pre-tokenization (splitting text into words -// on punctuation/whitespaces) and WordPiece (breaking words into subwords) -// into one pass. -//. * support_detokenization: Whether to enable the detokenization function. -// Setting it to true expands the size of the flatbuffer. As a reference, -// When using 120k multilingual BERT WordPiece vocab, the flatbuffer's size -// increases from ~5MB to ~6MB. -// Returns: -// The bytes of the flatbuffer that stores the model. -absl::StatusOr BuildModelAndExportToFlatBuffer( - const std::vector& vocab, int max_bytes_per_token, - absl::string_view suffix_indicator, absl::string_view unk_token, - bool no_pretokenization = false, bool support_detokenization = false); -} // namespace text -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FASt_WORDPIECE_TOKENIZER_MODEL_BUILDER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_generated.h similarity index 60% rename from tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.cc rename to tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_generated.h index f50d41a31..ddb91e0f3 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.cc +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_generated.h @@ -12,16 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.h" +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_GENERATED_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_GENERATED_H_ -#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_model_generated.h" -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER(Name(SentenceFragmenterV2OpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - SentenceFragmenterV2OpKernel); - -} // namespace text -} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_MODEL_GENERATED_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_test.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_test.cc deleted file mode 100644 index fea96e3ef..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_test.cc +++ /dev/null @@ -1,2554 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h" - -#include -#include -#include "absl/flags/flag.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h" - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::AnyOf; -using ::testing::ElementsAre; - -constexpr char kTestConfigPath[] = - "tensorflow_text/python/ops/test_data/" - "fast_wordpiece_tokenizer_model.fb"; - -TEST(FastWordpieceTokenizerTest, LoadAndTokenize) { - std::string config_flatbuffer; - auto status = tensorflow::ReadFileToString( - tensorflow::Env::Default(), kTestConfigPath, &config_flatbuffer); - ASSERT_TRUE(status.ok()); - - // The config_flatbuffer used here is built from the following config: - // * vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - // "##ghz", ""} - // * unk_token = "" - // * suffix_indicator = "##" - // * max_bytes_per_token = 100 - ASSERT_OK_AND_ASSIGN( - auto tokenizer, FastWordpieceTokenizer::Create(config_flatbuffer.data())); - - std::string input = "abcdefghz"; - std::vector output_tokens; - std::vector output_ids; - std::vector output_start_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(input, &output_tokens, &output_ids, &output_start_offsets, - &output_end_offsets); - EXPECT_THAT(output_tokens, ElementsAre("abc", "##de", "##f", "##ghz")); - EXPECT_THAT(output_ids, ElementsAre(1, 3, 6, 7)); - EXPECT_THAT(output_start_offsets, ElementsAre(0, 3, 5, 6)); - EXPECT_THAT(output_end_offsets, ElementsAre(3, 5, 6, 9)); -} - -using TestPunctuationVersionMismatch = testing::TestWithParam; - -TEST_P(TestPunctuationVersionMismatch, Test) { - // The config_flatbuffer used here is built from the following config: - // * vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - // "##ghz", ""} - // * unk_token = "" - // * suffix_indicator = "##" - // * max_bytes_per_token = 
100 - // * end_to_end = True - - const std::string kTestConfigUnicodePath = GetParam(); - - // We test the new punctuation symbol: \341\255\277, which was available in - // Unicode 16: https://www.fileformat.info/info/unicode/char//1b7f/index.htm, - // but not in 15.1. - // We also test an existing punctuation symbol ">". - std::string input = "abc>abc\341\255\277abc"; - - std::string config_flatbuffer; - auto status = tensorflow::ReadFileToString( - tensorflow::Env::Default(), kTestConfigUnicodePath, &config_flatbuffer); - ASSERT_TRUE(status.ok()); - - ASSERT_OK_AND_ASSIGN( - auto tokenizer, FastWordpieceTokenizer::Create(config_flatbuffer.data())); - - std::vector output_tokens; - std::vector output_ids; - std::vector output_start_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(input, &output_tokens, &output_ids, &output_start_offsets, - &output_end_offsets); - - // If the runtime environment has unicode <=15.1, "\341\255\277" is not a - // punctuation, so "abc\341\255\277abc" is one token. - // If the runtime environment has unicode >=16.0, "\341\255\277" is a - // punctuation, so tokens are "abc", "", "abc" - EXPECT_THAT(output_tokens.size(), AnyOf(3, 5)); - if (!u_ispunct(0x1b7f)) { - // We have a runtime environment of unicode <= 15.1. - EXPECT_THAT(output_tokens, ElementsAre("abc", "", "")); - EXPECT_THAT(output_ids, ElementsAre(1, 8, 8)); - EXPECT_THAT(output_start_offsets, ElementsAre(0, 3, 4)); - EXPECT_THAT(output_end_offsets, ElementsAre(3, 4, 13)); - } else { - // We have a runtime environment of unicode >= 16.0. 
- EXPECT_THAT(output_tokens, - ElementsAre("abc", "", "abc", "", "abc")); - EXPECT_THAT(output_ids, ElementsAre(1, 8, 1, 8, 1)); - EXPECT_THAT(output_start_offsets, ElementsAre(0, 3, 4, 7, 10)); - EXPECT_THAT(output_end_offsets, ElementsAre(3, 4, 7, 10, 13)); - } -} - -INSTANTIATE_TEST_SUITE_P(FastWordpieceTokenizerPunctuationTest, - TestPunctuationVersionMismatch, - testing::Values( - // Unicode v 15.1 config - "tensorflow_text/python/ops/test_data/" - "fast_wordpiece_tokenizer_model_ver_15_1.fb", - // Unicode v 16.0 config - "tensorflow_text/python/ops/test_data/" - "fast_wordpiece_tokenizer_model_ver_16_0.fb")); - -template -std::string ListToString(const std::vector& list) { - return absl::StrCat("[", absl::StrJoin(list, ", "), "]"); -} - -// Testing spec struct for parameterized tests. -struct Spec { - friend std::ostream& operator<<(std::ostream& os, const Spec& s) { - return os << "vocab: " << ListToString(s.vocab) << ", " - << "unk_token:" << s.unk_token << ", " - << "suffix_indicator:" << s.suffix_indicator << ", " - << "max_bytes_per_token:" << s.max_bytes_per_token << ", " - << "input:" << s.input << ", " - << "expected_tokens:" << ListToString(s.expected_tokens) << ", " - << "expected_token_ids:" << ListToString(s.expected_token_ids) - << ", " - << "expected_token_start_offsets:" - << ListToString(s.expected_token_start_offsets) << ", " - << "expected_token_end_offsets:" - << ListToString(s.expected_token_end_offsets) << std::endl; - } - - std::vector vocab; - std::string unk_token; - std::string suffix_indicator; - int max_bytes_per_token; - std::string input; - std::vector expected_tokens; - std::vector expected_token_ids; - std::vector expected_token_start_offsets = {}; - std::vector expected_token_end_offsets = {}; - // Only used when detokenizing the tokenized ids back to text. - std::string expected_detokenized_text; -}; - -// Parameterized tests specs for Tokenize() when input is a single word. 
-const std::vector& GetTestSpecsForTokenizeSingleWord() { - static const std::vector& v = *new std::vector{ - // Test suite 1, normal vocabulary. - // Test 0: Empty input. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - // Test 1: Basic. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_tokens = {"abc", "##de", "##f", "##ghz"}, - .expected_token_ids = {1, 3, 6, 7}, - .expected_token_start_offsets = {0, 3, 5, 6}, - .expected_token_end_offsets = {3, 5, 6, 9}, - }, - // Test 2: Collect more tokens at the end. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdef", - .expected_tokens = {"abc", "##de", "##f"}, - .expected_token_ids = {1, 3, 6}, - .expected_token_start_offsets = {0, 3, 5}, - .expected_token_end_offsets = {3, 5, 6}, - }, - // Test 3: Unseen character alone. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "X", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 4: Unseen character at the beginning. Result is . 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "Xde", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 5: Unseen character in the middle. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcXde", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {6}, - }, - // Test 6: Unseen character at the end. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcX", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {4}, - }, - // Test 7: Input has leading suffix indicator. Result is normal. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##deh", - .expected_tokens = {"##deh"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - // Test 8: Input has the leading suffix indicator. Vocab has "#" and - // "###". Result is normal. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##deh", - .expected_tokens = {"##deh"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - // Test 9: Input is the suffix indicator itself. Result is . 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 10: [PAD] is in the vocabulary. Input is [PAD]. - { - .vocab = {"[pad]", "a", "abc", "abcdefghi", "##de", "##defgxy", - "##deh", "##f", "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "[pad]", - .expected_tokens = {"[pad]"}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - // Test 11: [PAD] is not in the vocabulary. Input is [PAD]. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "[pad]", - .expected_tokens = {""}, - .expected_token_ids = {10}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - - // Test suite 2, input contains #. - // Test 12: Input is #. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 13: Input is #. Result is not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {"#"}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 14: Input is #. The suffix indicator is in the vocab. Result is - // not . 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {"#"}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 15: Input is the suffix indicator itself. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {""}, - .expected_token_ids = {9}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 16: Input is the suffix indicator itself. Result is not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"#", "###"}, - .expected_token_ids = {8, 9}, - .expected_token_start_offsets = {0, 1}, - .expected_token_end_offsets = {1, 2}, - }, - // Test 17: Input is the suffix indicator itself. The suffix indicator is - // in the vocab. Result is not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"##"}, - .expected_token_ids = {10}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 18: Input is ###. Result is . 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {""}, - .expected_token_ids = {9}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 19: Input is ###. Result is not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {"###"}, - .expected_token_ids = {9}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 20: Input is ###. The suffix indicator is in the vocab. Result is - // not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {"###"}, - .expected_token_ids = {9}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 21: Input is ####. Result is not . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "####", - .expected_tokens = {"###", "###"}, - .expected_token_ids = {9, 9}, - .expected_token_start_offsets = {0, 3}, - .expected_token_end_offsets = {3, 4}, - }, - // Test 22: Input is ####. The suffix indicator is in the vocab. Result - // is not . 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "#", "###", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "####", - .expected_tokens = {"###", "###"}, - .expected_token_ids = {9, 9}, - .expected_token_start_offsets = {0, 3}, - .expected_token_end_offsets = {3, 4}, - }, - - // Test suite 3, the vocabulary contains empty tokens ("", "##"). - // Test 23: The empty prefix token ("") and the empty suffix token ("##") - // are in the vocabulary. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_tokens = {"abc", "##de", "##f", "##ghz"}, - .expected_token_ids = {1, 3, 6, 7}, - .expected_token_start_offsets = {0, 3, 5, 6}, - .expected_token_end_offsets = {3, 5, 6, 9}, - }, - // Test 24: The empty prefix token ("") and the empty suffix ("##") token - // are in the vocabulary. Input is empty. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - // Test 25: The empty prefix token ("") and the empty suffix token ("##") - // are in the vocabulary. Input is the suffix indicator. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"##"}, - .expected_token_ids = {9}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 26: The empty prefix token ("") and the empty suffix token ("##") - // are in the vocabulary. 
There are vocab tokens after the empty vocab - // tokens in the vocab. Result is one vocab token. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", "xyz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "xyz", - .expected_tokens = {"xyz"}, - .expected_token_ids = {10}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 27: The empty prefix token ("") and the empty suffix ("##") token - // are in the vocabulary. There are vocab tokens after the empty vocab - // tokens in the vocab. Result has multiple tokens. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", "xy", "##z", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "xyz", - .expected_tokens = {"xy", "##z"}, - .expected_token_ids = {10, 11}, - .expected_token_start_offsets = {0, 2}, - .expected_token_end_offsets = {2, 3}, - }, - // Test 28: The empty prefix token ("") and the empty suffix token ("##") - // are in the vocabulary. Input has the leading suffix indicator. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##deh", - .expected_tokens = {"##deh"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - - // Test suite 4, No suffix tokens in the vocabulary. - // Test 29: No suffix tokens in the vocabulary. Result is normal. - { - .vocab = {"a", "abc", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abc", - .expected_tokens = {"abc"}, - .expected_token_ids = {1}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 30: No suffix tokens in the vocabulary. Result is . 
- { - .vocab = {"a", "abc", "de", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcde", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - // Test 31: No suffix tokens in the vocabulary. A different input. Result - // is . - { - .vocab = {"a", "abc", "de", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdz", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - // Test 32: No suffix tokens in the vocabulary. Input is #. Result is - // - { - .vocab = {"a", "abc", "de", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 33: No suffix tokens in the vocabulary. Input is #. Result is not - // . - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "#"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {"#"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 34: No suffix tokens in the vocabulary. Vocab has the suffix - // indicator. Input is #. - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 35: No suffix tokens in the vocabulary. Input is ##. Result is - // . 
- { - .vocab = {"a", "abc", "de", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 36: No suffix tokens in the vocabulary. Vocab has the suffix - // indicator. Input is #. Result is . - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "#", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {1}, - }, - // Test 37: No suffix tokens in the vocabulary. Vocab has the suffix - // indicator. Input is ##. - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"##"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 38: No suffix tokens in the vocabulary. Vocab has '#'. Input is - // ##. Result is . - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "#"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 39: No suffix tokens in the vocabulary. Vocab has the suffix - // indicator and "#". Input is ##. - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "##", "#"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"##"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 40: No suffix tokens in the vocabulary. Input is ###. Result is - // . 
- { - .vocab = {"a", "abc", "de", "abcdefghi", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 41: No suffix tokens in the vocabulary. Vocab has '#'. Input is - // ###. Result is . - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "#"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 42: No suffix tokens in the vocabulary. Vocab has the suffix - // indicator. Input is ###. - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {""}, - .expected_token_ids = {4}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 43: There is only one suffix tokens "###" in the vocabulary. - // Input is ###. - { - .vocab = {"a", "abc", "de", "abcdefghi", "", "###"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {"###"}, - .expected_token_ids = {5}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - - // Test suite 5, No prefix tokens in the vocabulary. - // Test 44: No prefix tokens in the vocabulary. Input is a prefix token. - { - .vocab = {"##a", "##abc", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abc", - .expected_tokens = {""}, - .expected_token_ids = {2}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 45: No prefix tokens in the vocabulary. Input is a suffix token. 
- { - .vocab = {"##a", "##abc", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##abc", - .expected_tokens = {"##abc"}, - .expected_token_ids = {1}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - - // Test suite 6, more tests. - // Test 46: Input is empty. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - // Test 47: Normal input. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "unwanted", - .expected_tokens = {"un", "##want", "##ed"}, - .expected_token_ids = {7, 4, 5}, - .expected_token_start_offsets = {0, 2, 6}, - .expected_token_end_offsets = {2, 6, 8}, - }, - // Test 48: Unseen character. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "unwantedX", - .expected_tokens = {""}, - .expected_token_ids = {1}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {9}, - }, - - // Test suite 7. Testing on long inputs (kMaxInputCharPerWord = 100). The - // word length below means the number of utf-8 bytes. - // Test 49: Word length = 99 (i.e., kMaxInputCharPerWord-1). 
- { - .vocab = {"", "0123456789", "##0123456789", "##012345678"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "01234567890123456789012345678901234567890123456789012345678" - "9012345678901234567890123456789012345678", - .expected_tokens = {"0123456789", "##0123456789", "##0123456789", - "##0123456789", "##0123456789", "##0123456789", - "##0123456789", "##0123456789", "##0123456789", - "##012345678"}, - .expected_token_ids = {1, 2, 2, 2, 2, 2, 2, 2, 2, 3}, - .expected_token_start_offsets = {0, 10, 20, 30, 40, 50, 60, 70, 80, - 90}, - .expected_token_end_offsets = {10, 20, 30, 40, 50, 60, 70, 80, 90, - 99}, - }, - // Test 50: Word length = 100 (i.e., kMaxInputCharPerWord). Contains a - // multi-bytes Unicode char. - { - .vocab = {"", "0123456789", "##0123456789", "##01234567", - /*U+05C3*/ "##\xD7\x83", "##a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "01234567890123456789012345678901234567890123456789012345678" - "901234567890123456789012345678901234567\xD7\x83", - .expected_tokens = {"0123456789", "##0123456789", "##0123456789", - "##0123456789", "##0123456789", "##0123456789", - "##0123456789", "##0123456789", "##0123456789", - "##01234567", "##\xD7\x83"}, - .expected_token_ids = {1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4}, - .expected_token_start_offsets = {0, 10, 20, 30, 40, 50, 60, 70, 80, - 90, 98}, - .expected_token_end_offsets = {10, 20, 30, 40, 50, 60, 70, 80, 90, 98, - 100}, - }, - // Test 51: Word length = 101 (i.e., kMaxInputCharPerWord+1). Contains a - // multi-bytes Unicode char. 
- { - .vocab = {"", "0123456789", "##0123456789", "##012345678", - /*U+05C3*/ "##\xD7\x83", "##a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "01234567890123456789012345678901234567890123456789012345678" - "9012345678901234567890123456789012345678\xD7\x83", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {101}, - }, - // Test 52: Word length = 101 (i.e., kMaxInputCharPerWord+1). - { - .vocab = {"", "0123456789", "##0123456789", "##012345678", - "##a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "01234567890123456789012345678901234567890123456789012345678" - "90123456789012345678901234567890123456789a", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {101}, - }, - // Test 53: Word length = 99 (i.e., kMaxInputCharPerWord-1). The word is - // not tokenizable. - { - .vocab = {"", "0123456789", "##0123456789", - "##012345678\xe2\x80\x8B", "##\xe2\x80\x8B"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "01234567890123456789012345678901234567890123456789012345678" - "9012345678901234567890123456789012345678", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {99}, - }, - - // Test suite 8. Normal vocab and inputs. - // Test 54. - { - .vocab = {"", "play", "see", "##ing", "##ed", "##es", "##ly", - "##on", "##s", "##able"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "play", - .expected_tokens = {"play"}, - .expected_token_ids = {1}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {4}, - }, - // Test 55. 
- { - .vocab = {"", "play", "see", "##ing", "##ed", "##es", "##ly", - "##on", "##s", "##able"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "playing", - .expected_tokens = {"play", "##ing"}, - .expected_token_ids = {1, 3}, - .expected_token_start_offsets = {0, 4}, - .expected_token_end_offsets = {4, 7}, - }, - // Test 56. - { - .vocab = {"", "play", "see", "##ing", "##ed", "##es", "##ly", - "##on", "##s", "##able"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "sees", - .expected_tokens = {"see", "##s"}, - .expected_token_ids = {2, 8}, - .expected_token_start_offsets = {0, 3}, - .expected_token_end_offsets = {3, 4}, - }, - // Test 57. - { - .vocab = {"", "play", "see", "##ing", "##ed", "##es", "##ly", - "##on", "##s", "##able", "u", "un", "##de", "##deni"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "undeniable", - .expected_tokens = {"un", "##deni", "##able"}, - .expected_token_ids = {11, 13, 9}, - .expected_token_start_offsets = {0, 2, 6}, - .expected_token_end_offsets = {2, 6, 10}, - }, - // Test 58. - { - .vocab = {"", "play", "see", "##ing", "##ed", "##es", "##ly", - "##on", "##s", "##able", "u", "un", "##de", "##deni", - "undeniable"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "undeniable", - .expected_tokens = {"undeniable"}, - .expected_token_ids = {14}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {10}, - }, - // Test 59. 
- { - .vocab = {"", "s", "su", "super", "##per", "##ca", - "##cali", "##f", "##fra", "##g", "##gil", "##i", - "##is", "##istic", "##e", "##ex", "##pi", "##pia", - "##li", "##lido", "##ci", "##cious", "##ous"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "supercalifragilisticexpialidocious", - .expected_tokens = {"super", "##cali", "##fra", "##gil", "##istic", - "##ex", "##pia", "##lido", "##cious"}, - .expected_token_ids = {3, 6, 8, 10, 13, 15, 17, 19, 21}, - .expected_token_start_offsets = {0, 5, 9, 12, 15, 20, 22, 25, 29}, - .expected_token_end_offsets = {5, 9, 12, 15, 20, 22, 25, 29, 34}, - }, - - // Test suite 9. Different unk_tokens. - // Test 60: Basic with a different unk_token. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_tokens = {"abc", "##de", "##f", "##ghz"}, - .expected_token_ids = {1, 3, 6, 7}, - .expected_token_start_offsets = {0, 3, 5, 6}, - .expected_token_end_offsets = {3, 5, 6, 9}, - }, - // Test 61: Untokenizable with a different unk_token. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefghzX", - .expected_tokens = {"[unk]"}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {10}, - }, - - // Test suite 10. Input is the unk_token. - // Test 62: Input is the unk_token. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "[unk]", - .expected_tokens = {"[unk]"}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {5}, - }, - - // Test suite 11. 
Input is the suffix indicator itself. - // Test 63: Suffix indicator is "##" and is tokenizable. - { - .vocab = {"#", "###", "a", "abc", "abcdefghi", "##de", "##defgxy", - "##deh", "##f", "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"#", "###"}, - .expected_token_ids = {0, 1}, - .expected_token_start_offsets = {0, 1}, - .expected_token_end_offsets = {1, 2}, - }, - // Test 64: Suffix indicator is "##" but not tokenizable. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"[unk]"}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 65: Suffix indicator is "##" and "##" is in the vocabulary. - { - .vocab = {"#", "###", "##", "a", "abc", "abcdefghi", "##de", - "##defgxy", "##deh", "##f", "##ghz", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"##"}, - .expected_token_ids = {2}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {2}, - }, - // Test 66: Suffix indicator is "###" and is tokenizable. - { - .vocab = {"#", "####", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "###", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {"#", "####", "####"}, - .expected_token_ids = {0, 1, 1}, - .expected_token_start_offsets = {0, 1, 2}, - .expected_token_end_offsets = {1, 2, 3}, - }, - // Test 67: Suffix indicator is "###" and is tokenizable. A different - // vocab. 
- { - .vocab = {"#", "####", "##", "[unk]"}, - .unk_token = "[unk]", - .suffix_indicator = "###", - .max_bytes_per_token = 100, - .input = "###", - .expected_tokens = {"##", "####"}, - .expected_token_ids = {2, 1}, - .expected_token_start_offsets = {0, 2}, - .expected_token_end_offsets = {2, 3}, - }, - - // Test suite 12, different suffix indicators. - // Test 68: A different suffix indicator. - { - .vocab = {"a", "abc", "abcdefghi", "de", "defgxy", - "deh", "f", "ghz", ""}, - .unk_token = "", - .suffix_indicator = "", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_tokens = {"abc", "de", "f", "ghz"}, - .expected_token_ids = {1, 3, 6, 7}, - .expected_token_start_offsets = {0, 3, 5, 6}, - .expected_token_end_offsets = {3, 5, 6, 9}, - }, - // Test 69: The suffix indicator is empty. - { - .vocab = {"a", "abc", "abcdefghi", "de", "defgxy", "deh", "f", "ghz", - ""}, - .unk_token = "", - .suffix_indicator = "", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_tokens = {"abc", "de", "f", "ghz"}, - .expected_token_ids = {1, 3, 6, 7}, - .expected_token_start_offsets = {0, 3, 5, 6}, - .expected_token_end_offsets = {3, 5, 6, 9}, - }, - // Test 70: The suffix indicator is empty. Input is empty. - { - .vocab = {"a", "abc", "abcdefghi", "de", "defgxy", "deh", "f", "ghz", - ""}, - .unk_token = "", - .suffix_indicator = "", - .max_bytes_per_token = 100, - .input = "", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - - // Test suite 13, multi-bytes chars in vocab and input. - // The following codepoints and their utf-8 encodings are used here: - // * U+03B1 (Greek Small Letter Alpha): "\xCE\xB1" - // * U+03B2 (Greek Small Letter Beta): "\xCE\xB2" - // * U+2EDA (Cjk Radical C-Simplified Leaf): b'\xE2\xBB\x9A' - // * U+2EDB (Cjk Radical C-Simplified Wind): b'\xE2\xBB\x9B' - // Test 71: multi-bytes chars in the vocab. 
- { - .vocab = {"", "abc", "a", "##bc", "a\xCE\xB1\xCE\xB2", - "\xCE\xB1", "##\xCE\xB1", "##\xCE\xB2", "\xE2\xBB\x9A"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abc", - .expected_tokens = {"abc"}, - .expected_token_ids = {1}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - // Test 72: input contains 2-bytes chars. - { - .vocab = {"", "abc", "a", "##bc", "a\xCE\xB1\xCE\xB2", - "\xCE\xB1", "##\xCE\xB1", "##\xCE\xB2", "\xE2\xBB\x9A"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "a\xCE\xB1\xCE\xB2\xCE\xB1\xCE\xB2", - .expected_tokens = {"a\xCE\xB1\xCE\xB2", "##\xCE\xB1", "##\xCE\xB2"}, - .expected_token_ids = {4, 6, 7}, - .expected_token_start_offsets = {0, 5, 7}, - .expected_token_end_offsets = {5, 7, 9}, - }, - // Test 73: input contains 3-bytes chars. - { - .vocab = {"", "abc", "a", "##bc", "a\xCE\xB1\xCE\xB2", - "\xCE\xB1", "##\xCE\xB1", "##\xCE\xB2", "\xE2\xBB\x9A"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "\xE2\xBB\x9A" - "bc\xCE\xB1", - .expected_tokens = {"\xE2\xBB\x9A", "##bc", "##\xCE\xB1"}, - .expected_token_ids = {8, 3, 6}, - .expected_token_start_offsets = {0, 3, 5}, - .expected_token_end_offsets = {3, 5, 7}, - }, - // Test 74: input contains unseen multi-bytes chars. 
- { - .vocab = {"", "abc", "a", "##bc", "a\xCE\xB1\xCE\xB2", - "\xCE\xB1", "##\xCE\xB1", "##\xCE\xB2", "\xE2\xBB\x9A"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "\xE2\xBB\x9B", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {3}, - }, - }; - return v; -} - -using TestTokenizeSingleWord = testing::TestWithParam; - -TEST_P(TestTokenizeSingleWord, Test) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token, - /*no_pretokenization=*/true)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_tokens; - std::vector output_ids; - std::vector output_begin_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(spec.input, &output_tokens, &output_ids, - &output_begin_offsets, &output_end_offsets); - EXPECT_THAT(output_tokens, spec.expected_tokens); - EXPECT_THAT(output_ids, spec.expected_token_ids); - EXPECT_THAT(output_begin_offsets, spec.expected_token_start_offsets); - EXPECT_THAT(output_end_offsets, spec.expected_token_end_offsets); -} - -TEST_P(TestTokenizeSingleWord, TestNoOutputPieces) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token, - true /* no_pretokenization */)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_ids; - std::vector output_begin_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(spec.input, &output_ids, &output_begin_offsets, - &output_end_offsets); - EXPECT_THAT(output_ids, spec.expected_token_ids); - EXPECT_THAT(output_begin_offsets, spec.expected_token_start_offsets); - 
EXPECT_THAT(output_end_offsets, spec.expected_token_end_offsets); -} - -TEST_P(TestTokenizeSingleWord, TestNoOutputPiecesOnlyOutputIds) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token, - true /* no_pretokenization */)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_ids; - tokenizer.Tokenize(spec.input, &output_ids); - EXPECT_THAT(output_ids, spec.expected_token_ids); -} - -TEST_P(TestTokenizeSingleWord, TestNoOutputPiecesWithPositiveSentenceOffsets) { - const Spec& spec = GetParam(); - const int offset_in_sentence = 123; - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token, - true /* no_pretokenization */)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_ids; - std::vector output_begin_offsets; - std::vector output_end_offsets; - std::vector expected_token_start_offsets( - spec.expected_token_start_offsets); - std::vector expected_token_end_offsets(spec.expected_token_end_offsets); - - for (int& offset : expected_token_start_offsets) { - offset += offset_in_sentence; - } - for (int& offset : expected_token_end_offsets) { - offset += offset_in_sentence; - } - - tokenizer.Tokenize(spec.input, &output_ids, &output_begin_offsets, - &output_end_offsets, - /*input_word_offset_in_text=*/offset_in_sentence); - EXPECT_THAT(output_begin_offsets, expected_token_start_offsets); - EXPECT_THAT(output_end_offsets, expected_token_end_offsets); -} - -INSTANTIATE_TEST_SUITE_P( - FastWordpieceTokenizerParameterizedTest, TestTokenizeSingleWord, - testing::ValuesIn(GetTestSpecsForTokenizeSingleWord())); - -// Test End-to-end FastWordPieceTokenization for tokenizing general texts. 
-const std::vector& GetTestSpecsForTokenizeText() { - static const std::vector& v = *new std::vector{ - // Test suite 1. End-to-end test including whitespace tokenization. - // Test 0: Input is empty. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - // Test 1: Input has only spaces. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " \t ", - .expected_tokens = {}, - .expected_token_ids = {}, - .expected_token_start_offsets = {}, - .expected_token_end_offsets = {}, - }, - // Test 2: Input is a single word. Result is OK. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdef", - .expected_tokens = {"abc", "##de", "##f"}, - .expected_token_ids = {1, 3, 6}, - .expected_token_start_offsets = {0, 3, 5}, - .expected_token_end_offsets = {3, 5, 6}, - }, - // Test 3: Input is a single word. Result is . - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcd", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {4}, - }, - // Test 4: Input contains multiple words, with several whitespaces in the - // middle. Result is OK. 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdef \t\t \tabcf", - .expected_tokens = {"abc", "##de", "##f", "abc", "##f"}, - .expected_token_ids = {1, 3, 6, 1, 6}, - .expected_token_start_offsets = {0, 3, 5, 11, 14}, - .expected_token_end_offsets = {3, 5, 6, 14, 15}, - }, - // Test 5: Input has multiple words, with leading and trailing spaces. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "\tabcdef abcf ", - .expected_tokens = {"abc", "##de", "##f", "abc", "##f"}, - .expected_token_ids = {1, 3, 6, 1, 6}, - .expected_token_start_offsets = {1, 4, 6, 9, 12}, - .expected_token_end_offsets = {4, 6, 7, 12, 13}, - }, - // Test 6: Input contains suffix indicator as words. Suffix indicator is - // in vocab. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "## abcde ## ##a", - .expected_tokens = {"", "", "abc", "##de", "", "", - "", "", "a"}, - .expected_token_ids = {8, 8, 1, 3, 8, 8, 8, 8, 0}, - .expected_token_start_offsets = {0, 1, 3, 6, 9, 10, 13, 14, 15}, - .expected_token_end_offsets = {1, 2, 6, 8, 10, 11, 14, 15, 16}, - }, - // Test 7: Input contains suffix indicator as words. Suffix indicator is - // in vocab. 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", "", "##"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "## abcde ## ##a ##f", - .expected_tokens = {"", "", "abc", "##de", "", "", - "", "", "a", "", "", ""}, - .expected_token_ids = {8, 8, 1, 3, 8, 8, 8, 8, 0, 8, 8, 8}, - .expected_token_start_offsets = {0, 1, 3, 6, 9, 10, 13, 14, 15, 17, - 18, 19}, - .expected_token_end_offsets = {1, 2, 6, 8, 10, 11, 14, 15, 16, 18, 19, - 20}, - }, - // Test 8: Input contains suffix indicator as words. Suffix indicator is - // not in vocab. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"", ""}, - .expected_token_ids = {8, 8}, - .expected_token_start_offsets = {0, 1}, - .expected_token_end_offsets = {1, 2}, - }, - // Test 9: Input contains unseen character words. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " a \tabcdeX \rabcdefghz abcdeXfghz Xabc abcd", - .expected_tokens = {"a", "", "abc", "##de", "##f", "##ghz", - "", "", ""}, - .expected_token_ids = {0, 8, 1, 3, 6, 7, 8, 8, 8}, - .expected_token_start_offsets = {1, 4, 12, 15, 17, 18, 22, 33, 38}, - .expected_token_end_offsets = {2, 10, 15, 17, 18, 21, 32, 37, 42}, - }, - // Test 10: Input contains untokenizable words. No spaces before or after. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefgx", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {8}, - }, - // Test 11: Input contains untokenizable words. One space before. 
- { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " abcdefgx", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {1}, - .expected_token_end_offsets = {9}, - }, - // Test 12: Input contains untokenizable words. One space after. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefgx ", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {0}, - .expected_token_end_offsets = {8}, - }, - // Test 13: Input has untokenizable words. One space before and after. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " abcdefgx ", - .expected_tokens = {""}, - .expected_token_ids = {8}, - .expected_token_start_offsets = {1}, - .expected_token_end_offsets = {9}, - }, - // Test 14: Input contains mix words with unseen characters. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " a \tabcdeX \rabcdefghz abcdeXfghz Xabc", - .expected_tokens = {"a", "", "abc", "##de", "##f", "##ghz", - "", ""}, - .expected_token_ids = {0, 8, 1, 3, 6, 7, 8, 8}, - .expected_token_start_offsets = {1, 4, 12, 15, 17, 18, 22, 33}, - .expected_token_end_offsets = {2, 10, 15, 17, 18, 21, 32, 37}, - }, - // Test 15: Another basic test. 
- { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "unwanted running", - .expected_tokens = {"un", "##want", "##ed", "runn", "##ing"}, - .expected_token_ids = {7, 4, 5, 8, 9}, - .expected_token_start_offsets = {0, 2, 6, 9, 13}, - .expected_token_end_offsets = {2, 6, 8, 13, 16}, - }, - // Test 16: Input has unseen characters. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "unwantedX running", - .expected_tokens = {"", "runn", "##ing"}, - .expected_token_ids = {0, 8, 9}, - .expected_token_start_offsets = {0, 10, 14}, - .expected_token_end_offsets = {9, 14, 17}, - }, - // Test 17: Input contains mix words with untokenizable words. - { - .vocab = {"a", "abc", "abcdefghi", "##de", "##defgxy", "##deh", "##f", - "##ghz", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " a \tabcdeX \rabcdefghz abcdeXfghz ab", - .expected_tokens = {"a", "", "abc", "##de", "##f", "##ghz", - "", ""}, - .expected_token_ids = {0, 8, 1, 3, 6, 7, 8, 8}, - .expected_token_start_offsets = {1, 4, 12, 15, 17, 18, 22, 33}, - .expected_token_end_offsets = {2, 10, 15, 17, 18, 21, 32, 35}, - }, - // Test 18: Input and vocab contains Unicode tokens. The Trie matching - // loop would stop at matching a partial word. - { - .vocab = {"\xE2\x82\xAC", "a", "abc", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " \xE2\x82\xAD abc", - .expected_tokens = {"", "abc"}, - .expected_token_ids = {3, 2}, - .expected_token_start_offsets = {1, 5}, - .expected_token_end_offsets = {4, 8}, - }, - // Test 19: Contains suffix indicator as a word. 
- { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "...", "#", "###"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##", - .expected_tokens = {"#", "#"}, - .expected_token_ids = {13, 13}, - .expected_token_start_offsets = {0, 1}, - .expected_token_end_offsets = {1, 2}, - }, - // Test 20: unknown words. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " X wantXwanted. \t ", - .expected_tokens = {"", "", "."}, - .expected_token_ids = {1, 1, 10}, - .expected_token_start_offsets = {1, 3, 14}, - .expected_token_end_offsets = {2, 14, 15}, - }, - // Test 21: After the loop, the next character is whitespace. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted. \t wa..", - .expected_tokens = {"want", "##ed", ".", "wa", ".", "."}, - .expected_token_ids = {3, 5, 10, 6, 10, 10}, - .expected_token_start_offsets = {2, 6, 8, 13, 15, 16}, - .expected_token_end_offsets = {6, 8, 9, 15, 16, 17}, - }, - // Test 22: After the loop, the next character is not a whitespace. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted.x \t wa..", - .expected_tokens = {"want", "##ed", ".", "", "wa", ".", "."}, - .expected_token_ids = {3, 5, 10, 1, 6, 10, 10}, - .expected_token_start_offsets = {2, 6, 8, 9, 14, 16, 17}, - .expected_token_end_offsets = {6, 8, 9, 10, 16, 17, 18}, - }, - // Test 23: After the loop, the next character is not a whitespace. And a - // trailing space. 
- { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted.x \t wa.. \n", - .expected_tokens = {"want", "##ed", ".", "", "wa", ".", "."}, - .expected_token_ids = {3, 5, 10, 1, 6, 10, 10}, - .expected_token_start_offsets = {2, 6, 8, 9, 14, 16, 17}, - .expected_token_end_offsets = {6, 8, 9, 10, 16, 17, 18}, - }, - // Test 24: After the loop, it's in the middle of a whitespace. The - // previous is tokenizable. - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - "##\xc2\xa1"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted\xc2\xa0\t wa", - .expected_tokens = {"want", "##ed", "wa"}, - .expected_token_ids = {1, 3, 4}, - .expected_token_start_offsets = {2, 6, 12}, - .expected_token_end_offsets = {6, 8, 14}, - }, - // Test 25: After the loop, it's in the middle of a whitespace. The - // previous is tokenizable (a punctuation). - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - "\xc2\xa1", "##\xc2\xa1"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted.\xc2\xa0\t wa", - .expected_tokens = {"want", "##ed", ".", "wa"}, - .expected_token_ids = {1, 3, 5, 4}, - .expected_token_start_offsets = {2, 6, 8, 13}, - .expected_token_end_offsets = {6, 8, 9, 15}, - }, - // Test 26: After the loop, it's in the middle of a whitespace. The - // previous is untokenizable. - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - "##e\xC2\xA1", "##\xC2\xA1"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wante\xc2\xa0\t wa", - .expected_tokens = {"", "wa"}, - .expected_token_ids = {0, 4}, - .expected_token_start_offsets = {2, 11}, - .expected_token_end_offsets = {7, 13}, - }, - - // Test suite 2. 
End-to-end test including whitespace tokenization and - // split on punctuation. - // Test 27. Basic case 1. - { - .vocab = - { - "", "don", "##'", "##t", "tread", "##ness", - "hel", "##lo", "there", "my", "na", "##me", - "is", "ter", "##ry", "what", "##cha", "##ma", - "##call", "##it?", "you", "said", - }, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "hello there my name is terry", - .expected_tokens = {"hel", "##lo", "there", "my", "na", "##me", "is", - "ter", "##ry"}, - .expected_token_ids = {6, 7, 8, 9, 10, 11, 12, 13, 14}, - .expected_token_start_offsets = {0, 3, 6, 12, 15, 17, 20, 23, 26}, - .expected_token_end_offsets = {3, 5, 11, 14, 17, 19, 22, 26, 28}, - }, - // Test 28. Basic case 2. - { - .vocab = - { - "", "don", "##'", "##t", "tread", "##ness", - "hel", "##lo", "there", "my", "na", "##me", - "is", "ter", "##ry", "what", "##cha", "##ma", - "##call", "##it?", "you", "said", - }, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "whatchamacallit? you said", - .expected_tokens = {"", "", "you", "said"}, - .expected_token_ids = {0, 0, 20, 21}, - .expected_token_start_offsets = {0, 15, 17, 21}, - .expected_token_end_offsets = {15, 16, 20, 25}, - }, - // Test 29. Basic case 3. Punctuation is an independant word in the vocab. - { - .vocab = - { - "", "don", "##'", "##t", "tread", "##ness", - "hel", "##lo", "there", "my", "na", "##me", - "is", "ter", "##ry", "what", "##cha", "##ma", - "##call", "##it?", "you", "said", "##it", "?", - }, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "whatchamacallit? you said", - .expected_tokens = {"what", "##cha", "##ma", "##call", "##it", "?", - "you", "said"}, - .expected_token_ids = {15, 16, 17, 18, 22, 23, 20, 21}, - .expected_token_start_offsets = {0, 4, 7, 9, 13, 15, 17, 21}, - .expected_token_end_offsets = {4, 7, 9, 13, 15, 16, 20, 25}, - }, - // Test 30. Basic case 4 with untokenizable words. 
- { - .vocab = - { - "", "don", "'", "t", "tread", "##ness", - "hel", "##lo", "there", "my", "na", "##me", - "is", "ter", "##ry", "what", "##cha", "##ma", - "##call", "##it?", "you", "said", - }, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "don't tread cantfindme treadcantfindme", - .expected_tokens = {"don", "'", "t", "tread", "", ""}, - .expected_token_ids = {1, 2, 3, 4, 0, 0}, - .expected_token_start_offsets = {0, 3, 4, 6, 12, 23}, - .expected_token_end_offsets = {3, 4, 5, 11, 22, 38}, - }, - // Test 31: Basic case 5. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "unwanted.", - .expected_tokens = {"un", "##want", "##ed", "."}, - .expected_token_ids = {7, 4, 5, 10}, - .expected_token_start_offsets = {0, 2, 6, 8}, - .expected_token_end_offsets = {2, 6, 8, 9}, - }, - // Test 32: Basic case 6. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " want.wanted. \t ", - .expected_tokens = {"want", ".", "want", "##ed", "."}, - .expected_token_ids = {3, 10, 3, 5, 10}, - .expected_token_start_offsets = {2, 6, 7, 11, 13}, - .expected_token_end_offsets = {6, 7, 11, 13, 14}, - }, - // Test 33: Basic with unseen characters (as a single word). - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " X want.wanted. \t ", - .expected_tokens = {"", "want", ".", "want", "##ed", "."}, - .expected_token_ids = {1, 3, 10, 3, 5, 10}, - .expected_token_start_offsets = {1, 3, 7, 8, 12, 14}, - .expected_token_end_offsets = {2, 7, 8, 12, 14, 15}, - }, - // Test 34: Basic with unseen characters (in a word before a punctuation). 
- { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " X wantX.wanted. \t ", - .expected_tokens = {"", "", ".", "want", "##ed", "."}, - .expected_token_ids = {1, 1, 10, 3, 5, 10}, - .expected_token_start_offsets = {1, 3, 8, 9, 13, 15}, - .expected_token_end_offsets = {2, 8, 9, 13, 15, 16}, - }, - // Test 35: Basic with unseen characters (in the middle of a word). - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " X wantXwanted. \t ", - .expected_tokens = {"", "", "."}, - .expected_token_ids = {1, 1, 10}, - .expected_token_start_offsets = {1, 3, 14}, - .expected_token_end_offsets = {2, 14, 15}, - }, - // Test 36: Basic with unseen characters and a leading period. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " X .wantXwanted. \t ", - .expected_tokens = {"", ".", "", "."}, - .expected_token_ids = {1, 10, 1, 10}, - .expected_token_start_offsets = {1, 3, 4, 15}, - .expected_token_end_offsets = {2, 4, 15, 16}, - }, - // Test 37: Contains ellipsis (as "....."). - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted. 
\t wa.....", - .expected_tokens = {"want", "##ed", ".", "wa", ".", ".", ".", ".", - "."}, - .expected_token_ids = {3, 5, 10, 6, 10, 10, 10, 10, 10}, - .expected_token_start_offsets = {2, 6, 8, 13, 15, 16, 17, 18, 19}, - .expected_token_end_offsets = {6, 8, 9, 15, 16, 17, 18, 19, 20}, - }, - // Test 38: After the loop, the next character is an unknown punctuation; - // the previous can be tokenized. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted, \t wa", - .expected_tokens = {"want", "##ed", "", "wa"}, - .expected_token_ids = {3, 5, 1, 6}, - .expected_token_start_offsets = {2, 6, 8, 13}, - .expected_token_end_offsets = {6, 8, 9, 15}, - }, - // Test 39: After the loop, the next character is an unknown punctuation; - // the previous can be tokenized. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted., \t wa", - .expected_tokens = {"want", "##ed", ".", "", "wa"}, - .expected_token_ids = {3, 5, 10, 1, 6}, - .expected_token_start_offsets = {2, 6, 8, 9, 14}, - .expected_token_end_offsets = {6, 8, 9, 10, 16}, - }, - // Test 40: After the loop, the next character is an unknown punctuation; - // the previous is empty. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " , wanted, \t wa", - .expected_tokens = {"", "want", "##ed", "", "wa"}, - .expected_token_ids = {1, 3, 5, 1, 6}, - .expected_token_start_offsets = {1, 3, 7, 9, 14}, - .expected_token_end_offsets = {2, 7, 9, 10, 16}, - }, - // Test 41: After the loop, the next character is an unknown punctuation; - // the previous can not be tokenized. 
- { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wante, \t wa", - .expected_tokens = {"", "", "wa"}, - .expected_token_ids = {1, 1, 6}, - .expected_token_start_offsets = {2, 7, 12}, - .expected_token_end_offsets = {7, 8, 14}, - }, - // Test 42: After the loop, in the middle of an unknown punctuation. - // Previous is tokenizable. - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - /*U+05C3*/ "\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted\xd7\x86xyz \t wa", - .expected_tokens = {"want", "##ed", "", "", "wa"}, - .expected_token_ids = {1, 3, 0, 0, 4}, - .expected_token_start_offsets = {2, 6, 8, 10, 17}, - .expected_token_end_offsets = {6, 8, 10, 13, 19}, - }, - // Test 43: After the loop, in the middle of an unknown punctuation. - // Previous is tokenizable. - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - /*U+05C3*/ "\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted.\xd7\x86xyz \t wa", - .expected_tokens = {"want", "##ed", ".", "", "", "wa"}, - .expected_token_ids = {1, 3, 5, 0, 0, 4}, - .expected_token_start_offsets = {2, 6, 8, 9, 11, 18}, - .expected_token_end_offsets = {6, 8, 9, 11, 14, 20}, - }, - // Test 44: After the loop, in the middle of an unknown punctuation. - // Previous is not tokenizable. 
- { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - /*U+05C3*/ "##e\xD7\x83", - /*U+05C3*/ "\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wante\xd7\x86xyz \t wa", - .expected_tokens = {"", "", "", "wa"}, - .expected_token_ids = {0, 0, 0, 4}, - .expected_token_start_offsets = {2, 7, 9, 16}, - .expected_token_end_offsets = {7, 9, 12, 18}, - }, - // Test 45: Fails to match the first character in the beginning. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "xyz \t wa", - .expected_tokens = {"", "wa"}, - .expected_token_ids = {1, 6}, - .expected_token_start_offsets = {0, 7}, - .expected_token_end_offsets = {3, 9}, - }, - // Test 46: After the loop, the next character is not a whitespace nor - // punctuation. Trie fails to recognize the first character. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wanted.xyz \t wa", - .expected_tokens = {"want", "##ed", ".", "", "wa"}, - .expected_token_ids = {3, 5, 10, 1, 6}, - .expected_token_start_offsets = {2, 6, 8, 9, 16}, - .expected_token_end_offsets = {6, 8, 9, 12, 18}, - }, - // Test 47: After the loop, the next character is not a whitespace nor - // punctuation. Previous is not tokenizable. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wantedxyz \t wa", - .expected_tokens = {"", "wa"}, - .expected_token_ids = {1, 6}, - .expected_token_start_offsets = {2, 15}, - .expected_token_end_offsets = {11, 17}, - }, - // Test 48: After the loop, the next character is not a whitespace nor - // punctuation. 
Previous is not tokenizable. - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = " wantexyz \t wa", - .expected_tokens = {"", "wa"}, - .expected_token_ids = {1, 6}, - .expected_token_start_offsets = {2, 14}, - .expected_token_end_offsets = {10, 16}, - }, - // Test 49: Unknown punctuation followed by unseen character. - { - .vocab = {"", "want", "##want", "##ed", "wa", ".", "##.", "...", - /*U+05C3*/ "##e\xD7\x83", - /*U+05C3*/ "\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "wanted\xd7\x86xyz", - .expected_tokens = {"want", "##ed", "", ""}, - .expected_token_ids = {1, 3, 0, 0}, - .expected_token_start_offsets = {0, 4, 6, 8}, - .expected_token_end_offsets = {4, 6, 8, 11}, - }, - // Test 50: Ellipsis is mapped to ""s when "." is not in vocab. - { - .vocab = {"", "want", "##want", "##ed", "wa", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "wanted...", - .expected_tokens = {"want", "##ed", "", "", ""}, - .expected_token_ids = {1, 3, 0, 0, 0}, - .expected_token_start_offsets = {0, 4, 6, 7, 8}, - .expected_token_end_offsets = {4, 6, 7, 8, 9}, - }, - - // Test suite 3. End-to-end test including whitespace and punctuation - // tokenization on max_bytes_per_token = 10. - // Test 51: Word length = 9 (i.e., max_bytes_per_token-1). - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 012345678 ", - .expected_tokens = {"01234", "##5678"}, - .expected_token_ids = {1, 2}, - .expected_token_start_offsets = {2, 7}, - .expected_token_end_offsets = {7, 11}, - }, - // Test 52: Word length = 10 (i.e., max_bytes_per_token). 
- { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 0123456789 ", - .expected_tokens = {"01234", "##56789"}, - .expected_token_ids = {1, 3}, - .expected_token_start_offsets = {2, 7}, - .expected_token_end_offsets = {7, 12}, - }, - // Test 53: Word length = 9, followed by a multi-bytes Unicode punctuation - // char, which is a hebrew punctuation "sof pasquq". - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 012345678\xD7\x83 ", - .expected_tokens = {"01234", "##5678", ""}, - .expected_token_ids = {1, 2, 0}, - .expected_token_start_offsets = {2, 7, 11}, - .expected_token_end_offsets = {7, 11, 13}, - }, - // Test 54: Word length = 11 (i.e., max_bytes_per_token+1). The 10th - // char is on Unicode boundary. - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83", "##a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 0123456789a ", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {2}, - .expected_token_end_offsets = {13}, - }, - // Test 55: Word length = 10 (i.e., max_bytes_per_token). The next char - // (\xe2\x80\x80) is a whitespace. - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83", "##a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 0123456789\xe2\x80\x80 ", - .expected_tokens = {"01234", "##56789"}, - .expected_token_ids = {1, 3}, - .expected_token_start_offsets = {2, 7}, - .expected_token_end_offsets = {7, 12}, - }, - // Test 56: Word length = 9 (i.e., max_bytes_per_token-1). The next is - // a multi-byte whitespace. The 10th char is in the middle of the - // whitespace. 
- { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83", "##a", "##\xe2\x80\x8B"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 012345678\xe2\x80\x80 ", - .expected_tokens = {"01234", "##5678"}, - .expected_token_ids = {1, 2}, - .expected_token_start_offsets = {2, 7}, - .expected_token_end_offsets = {7, 11}, - }, - // Test 57: Word length = 9 (i.e., max_bytes_per_token-1). The next is a - // multi-byte whitespace. The 10th char is in the middle of the - // whitespace. The word is not tokenizable. - { - .vocab = {"", "01234", "##56789", "##5678\xe2\x80\x8B", - /*U+05C3*/ "##\xD7\x83", "##a", "##\xe2\x80\x8B"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " 012345678\xe2\x80\x80 ", - .expected_tokens = {""}, - .expected_token_ids = {0}, - .expected_token_start_offsets = {2}, - .expected_token_end_offsets = {11}, - }, - // Test 58: Word length = 9 (i.e., max_bytes_per_token-1) plus a - // trailing punctuation. - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "##\xD7\x83", "##a", "."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .012345678. ", - .expected_tokens = {".", "01234", "##5678", "."}, - .expected_token_ids = {6, 1, 2, 6}, - .expected_token_start_offsets = {2, 3, 8, 12}, - .expected_token_end_offsets = {3, 8, 12, 13}, - }, - // Test 59: Word length = 9 (i.e., max_bytes_per_token-1) plus a - // trailing punctuation, followed by more words. 
- { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "\xD7\x83", "##a", ".", "...", "a"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .012345678.a ", - .expected_tokens = {".", "01234", "##5678", ".", "a"}, - .expected_token_ids = {6, 1, 2, 6, 8}, - .expected_token_start_offsets = {2, 3, 8, 12, 13}, - .expected_token_end_offsets = {3, 8, 12, 13, 14}, - }, - // Test 60: Word length = 10 (i.e., max_bytes_per_token) plus a - // trailing punctuation, and the word is tokenizable. - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "\xD7\x83", "##a", ".", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .0123456789. ", - .expected_tokens = {".", "01234", "##56789", "."}, - .expected_token_ids = {6, 1, 3, 6}, - .expected_token_start_offsets = {2, 3, 8, 13}, - .expected_token_end_offsets = {3, 8, 13, 14}, - }, - // Test 61: Word length = 10 (i.e., max_bytes_per_token) plus a - // trailing unknown punctuation, and the word is tokenizable. - { - .vocab = {"", "01234", "##5678", "##56789", "##a", ".", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .0123456789\xD7\x83 ", - .expected_tokens = {".", "01234", "##56789", ""}, - .expected_token_ids = {5, 1, 3, 0}, - .expected_token_start_offsets = {2, 3, 8, 13}, - .expected_token_end_offsets = {3, 8, 13, 15}, - }, - // Test 62: Word length = 11 (i.e., max_bytes_per_token+1). - { - .vocab = {"", "01234", "##5678", "##56789", - /*U+05C3*/ "\xD7\x83", "##a", ".", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .0123456789Z ", - .expected_tokens = {".", ""}, - .expected_token_ids = {6, 0}, - .expected_token_start_offsets = {2, 3}, - .expected_token_end_offsets = {3, 14}, - }, - // Test 63: Word length = 11 (i.e., max_bytes_per_token+1). 
- // The input would be tokenizable if `max_byte_per_token` is set to be - // greater or equal to `word_length`. - { - .vocab = {"", "0123456789", "##0123456789", "##012345678abc", - /*U+05C3*/ "\xD7\x83", "##a", ".", "..."}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = " .012345678a. ", - .expected_tokens = {".", "", "."}, - .expected_token_ids = {6, 0, 6}, - .expected_token_start_offsets = {2, 3, 13}, - .expected_token_end_offsets = {3, 13, 14}, - }, - // Test 64: Input is "". - { - .vocab = {"", "0123456789", "##0123456789", "##012345678abc", - /*U+05C3*/ "\xD7\x83", "##a", ".", "...", ">"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = ".", - .expected_tokens = {"", "", ">", "."}, - .expected_token_ids = {0, 0, 8, 6}, - .expected_token_start_offsets = {0, 1, 4, 5}, - .expected_token_end_offsets = {1, 4, 5, 6}, - }, - - // Test suite 4: Test different suffix indicators. - // Test 65: Suffix indicator is "##". Input contains "##". - { - .vocab = {"", "", "", "want", "##want", "##ed", "wa", - "un", "runn", "##ing", ".", "##.", "...", "#", "##", "###"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "## running", - .expected_tokens = {"#", "#", "runn", "##ing"}, - .expected_token_ids = {13, 13, 8, 9}, - .expected_token_start_offsets = {0, 1, 3, 7}, - .expected_token_end_offsets = {1, 2, 7, 10}, - }, - // Test 66: Test suffix indicator "". - { - .vocab = {"", "want", "want", "ed", "wa", "un", - "runn", "ing", "#", "."}, - .unk_token = "", - .suffix_indicator = "", - .max_bytes_per_token = 100, - .input = "## running. <", - .expected_tokens = {"#", "#", "runn", "ing", ".", ""}, - .expected_token_ids = {8, 8, 6, 7, 9, 0}, - .expected_token_start_offsets = {0, 1, 3, 7, 10, 12}, - .expected_token_end_offsets = {1, 2, 7, 10, 11, 13}, - }, - // Test 67: Test suffix indicator "suffix>". 
Suffix indicator appears in - // the input as a single word after a punctuation. - { - .vocab = {"", "want", "suffix>want", "suffix>ed", "wa", "un", - "runn", "suffix>ing", "#", "su", "suffix>ffix", "suffix"}, - .unk_token = "", - .suffix_indicator = "suffix>", - .max_bytes_per_token = 100, - .input = "#suffix> running", - .expected_tokens = {"#", "suffix", "", "runn", "suffix>ing"}, - .expected_token_ids = {8, 11, 0, 6, 7}, - .expected_token_start_offsets = {0, 1, 7, 9, 13}, - .expected_token_end_offsets = {1, 7, 8, 13, 16}, - }, - // Test 68: Test suffix indicator "suffix>". Suffix indicator appears in - // the input as a single word after a punctuation. - { - .vocab = {"", "want", "suffix>want", "suffix>ed", "wa", "un", - "runn", "suffix>ing", "#", "su", "suffix>ffix"}, - .unk_token = "", - .suffix_indicator = "suffix>", - .max_bytes_per_token = 100, - .input = "#suffix> running", - .expected_tokens = {"#", "su", "suffix>ffix", "", "runn", - "suffix>ing"}, - .expected_token_ids = {8, 9, 10, 0, 6, 7}, - .expected_token_start_offsets = {0, 1, 3, 7, 9, 13}, - .expected_token_end_offsets = {1, 3, 7, 8, 13, 16}, - }, - // Test 69: Test suffix indicator "", "runn", "", "su", "", "runn", ">>". Suffix indicator appears in the - // input. 
- { - .vocab = {"", "want", ">>>want", ">>>ed", "wa", "un", "runn", - ">>>ing", "#", "su", ">>>ffix"}, - .unk_token = "", - .suffix_indicator = ">>>", - .max_bytes_per_token = 100, - .input = "#suffix>>> running", - .expected_tokens = {"#", "su", ">>>ffix", "", "", "", - "runn", ">>>ing"}, - .expected_token_ids = {8, 9, 10, 0, 0, 0, 6, 7}, - .expected_token_start_offsets = {0, 1, 3, 7, 8, 9, 11, 15}, - .expected_token_end_offsets = {1, 3, 7, 8, 9, 10, 15, 18}, - }, - // Test 72: Test suffix indicator "<", "runn", "<", "runn", "XYZing", "<", "X", "XYZYZ"}, - .unk_token = "", - .suffix_indicator = "XYZ", - .max_bytes_per_token = 100, - .input = "XYZ running", - .expected_tokens = {"X", "XYZYZ", "runn", "XYZing"}, - .expected_token_ids = {4, 5, 1, 2}, - .expected_token_start_offsets = {0, 1, 4, 8}, - .expected_token_end_offsets = {1, 3, 8, 11}, - }, - // Test 74: Test suffix indicator "XYZ", which appears in the - // vocab and input sentence as a single word. - { - .vocab = {"", "runn", "XYZing", "<", "X", "XYZYZ", "XYZ"}, - .unk_token = "", - .suffix_indicator = "XYZ", - .max_bytes_per_token = 100, - .input = "XYZ running", - .expected_tokens = {"XYZ", "runn", "XYZing"}, - .expected_token_ids = {6, 1, 2}, - .expected_token_start_offsets = {0, 4, 8}, - .expected_token_end_offsets = {3, 8, 11}, - }, - // Test suite 5: Test multi-byte punctuation and Chinese characters. - // Test 75: Contains a multi-bytes Unicode punctuation char "\xEF\xBC\x8C" - // followed by a tokenizable word. 
- { - .vocab = {"", "want", "##ed", "ABC", "\xEF\xBC\x8C", "##ABC"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = "wanted\xEF\xBC\x8C" - "ABC", - .expected_tokens = {"want", "##ed", "\xEF\xBC\x8C", "ABC"}, - .expected_token_ids = {1, 2, 4, 3}, - .expected_token_start_offsets = {0, 4, 6, 9}, - .expected_token_end_offsets = {4, 6, 9, 12}, - }, - // Test 76: Contains a multi-bytes Unicode punctuation char "\xEF\xBC\x8C" - // (absent in the vocab) followed by a tokenizable word. - { - .vocab = {"", "want", "##ed", "ABC", "\xEF\xBC\x8C", "##ABC"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = "wanted\xD7\x83" - "ABC", - .expected_tokens = {"want", "##ed", "", "ABC"}, - .expected_token_ids = {1, 2, 0, 3}, - .expected_token_start_offsets = {0, 4, 6, 8}, - .expected_token_end_offsets = {4, 6, 8, 11}, - }, - // Test 77: Contains a multi-bytes Unicode chinese character \xe4\xb8\x81, - // which is considered as a single word in Bert, so it's treated in the - // same way as punctuation characters by the tokenizer. - { - .vocab = {"", "want", "##ed", "ABC", "\xe4\xb8\x81", "##ABC"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = "wanted\xe4\xb8\x81" - "ABC", - .expected_tokens = {"want", "##ed", "\xe4\xb8\x81", "ABC"}, - .expected_token_ids = {1, 2, 4, 3}, - .expected_token_start_offsets = {0, 4, 6, 9}, - .expected_token_end_offsets = {4, 6, 9, 12}, - }, - // Test 78: Contains a multi-bytes Unicode chinese character \xe4\xb8\x81. 
- { - .vocab = {"", "want", "##ed", "ABC", "##ABC", - "wanted\xe4\xb8\x81"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = "wanted\xe4\xb8\x81" - "ABC", - .expected_tokens = {"want", "##ed", "", "ABC"}, - .expected_token_ids = {1, 2, 0, 3}, - .expected_token_start_offsets = {0, 4, 6, 9}, - .expected_token_end_offsets = {4, 6, 9, 12}, - }, - // Test 79: Contains a multi-bytes Unicode chinese character \xe4\xb8\x81, - // which is included in the vocab as the suffix of a word. - { - .vocab = {"", "want", "##ed", "ABC", "##ABC", - "wanted\xe4\xb8\x81"}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 10, - .input = "wanted\xe4\xb8\x81" - "ABC", - .expected_tokens = {"want", "##ed", "", "ABC"}, - .expected_token_ids = {1, 2, 0, 3}, - .expected_token_start_offsets = {0, 4, 6, 9}, - .expected_token_end_offsets = {4, 6, 9, 12}, - }}; - return v; -} - -using TestTokenizeText = testing::TestWithParam; - -TEST_P(TestTokenizeText, Test) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_tokens; - std::vector output_ids; - std::vector output_begin_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(spec.input, &output_tokens, &output_ids, - &output_begin_offsets, &output_end_offsets); - EXPECT_THAT(output_tokens, spec.expected_tokens); - EXPECT_THAT(output_ids, spec.expected_token_ids); - EXPECT_THAT(output_begin_offsets, spec.expected_token_start_offsets); - EXPECT_THAT(output_end_offsets, spec.expected_token_end_offsets); -} - -TEST_P(TestTokenizeText, TestNoOutputPieces) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - 
spec.suffix_indicator, spec.unk_token)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_ids; - std::vector output_begin_offsets; - std::vector output_end_offsets; - tokenizer.Tokenize(spec.input, &output_ids, &output_begin_offsets, - &output_end_offsets); - EXPECT_THAT(output_ids, spec.expected_token_ids); - EXPECT_THAT(output_begin_offsets, spec.expected_token_start_offsets); - EXPECT_THAT(output_end_offsets, spec.expected_token_end_offsets); -} - -TEST_P(TestTokenizeText, TestNoOutputPiecesOnlyOutputIds) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - std::vector output_ids; - tokenizer.Tokenize(spec.input, &output_ids); - EXPECT_THAT(output_ids, spec.expected_token_ids); -} - -INSTANTIATE_TEST_SUITE_P(EndToEndFastWordpieceTokenizerParameterizedTest, - TestTokenizeText, - testing::ValuesIn(GetTestSpecsForTokenizeText())); - -// Test the detokenization function of FastWordPieceTokenizer. -const std::vector& GetTestSpecsForTokenizeDetokenize() { - static const std::vector& v = *new std::vector{ - // Test 0: Input is a single word. - { - .vocab = {"a", "abc", "##de", "##defgxy", "##deh", "##f", "##ghz", - ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "abcdefghz", - .expected_token_ids = {1, 2, 5, 6}, - .expected_detokenized_text = "abcdefghz", - }, - // Test 1: Input is a sentence. - { - .vocab = {"a", "abc", "##de", "##c", "##f", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "a abc abcde ab", - .expected_token_ids = {0, 1, 1, 2, 5}, - .expected_detokenized_text = "a abc abcde ", - }, - // Test 2: Input has the leading suffix indicator. 
- { - .vocab = {"a", "abc", "##de", "##deh", "##f", ""}, - .unk_token = "", - .suffix_indicator = "##", - .max_bytes_per_token = 100, - .input = "##deh abcde", - .expected_token_ids = {3, 1, 2}, - .expected_detokenized_text = "##deh abcde", - }, - }; - return v; -} -using TestTokenizeDetokenize = testing::TestWithParam; - -TEST_P(TestTokenizeDetokenize, Test) { - const Spec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - std::string flatbuffer, - BuildModelAndExportToFlatBuffer(spec.vocab, spec.max_bytes_per_token, - spec.suffix_indicator, spec.unk_token, - /*no_pretokenization=*/true, - /*support_detokenization=*/true)); - ASSERT_OK_AND_ASSIGN(auto tokenizer, - FastWordpieceTokenizer::Create(flatbuffer.data())); - - // Test detokenization. - ASSERT_OK_AND_ASSIGN(auto output_text, - tokenizer.Detokenize(spec.expected_token_ids)); - EXPECT_THAT(output_text, spec.expected_detokenized_text); -} - -INSTANTIATE_TEST_SUITE_P( - FastWordpieceTokenizerDetokenizeParameterizedTest, TestTokenizeDetokenize, - testing::ValuesIn(GetTestSpecsForTokenizeDetokenize())); - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.cc deleted file mode 100644 index 18c58af04..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h" - -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -using TokenizeOpKernel = tflite::shim::TfLiteOpKernel< - tensorflow::text::FastWordpieceTokenizeWithOffsetsOp>; - -using DetokenizeOpKernel = - tflite::shim::TfLiteOpKernel; - -extern "C" void AddFastWordpieceTokenize(tflite::MutableOpResolver* resolver) { - TokenizeOpKernel::Add(resolver); -} - -extern "C" void AddFastWordpieceDetokenize( - tflite::MutableOpResolver* resolver) { - DetokenizeOpKernel::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h index bd1c176c5..ad1f9b6fe 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h @@ -12,24 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ -#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_tflite.h" -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddFastWordpieceTokenize(::tflite::MutableOpResolver* resolver); - -extern "C" void AddFastWordpieceDetokenize( - ::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h index a147654f0..a5f6df1bd 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h @@ -12,261 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// To optimize speed/memory usage, we assume: -// * The WordPiece vocabulary has at most 2^22 = 4M tokens. -// * No token from the vocabulary has more than 256 bytes. -// -// The assumptions are adjustable by setting the constants defined in this file. -// -// Note: by recompiling the underlying trie library and the helper functions in -// this file to use 64-bit (or even larger) integers, we can support even a -// larger vocab size and longer vocab tokens. Still, we believe the current -// implementation covers all real cases. 
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_UTILS_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_UTILS_H_ -#include - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/umachine.h" - -namespace tensorflow { -namespace text { -namespace fast_wordpiece_tokenizer_utils { - -// This header assumes that is 32-bit integer types. -static_assert(sizeof(int) == 4, "FastWordpieceTokenizer requires 4-byte int."); - -//////////////////////////////////////////////////////////////////////////////// -// Constants for token encoding. -// -// The constants below define a 32-bit compact token representation that encodes -// (1) the token id, (2) the token length (minus 1, and without the suffix -// indicator, in utf-8 bytes), and (3) is_suffix_token (i.e., the token starts -// with the suffix indicator (say) "##"). -// -// The encoded value is stored on the darts_clone trie as well as in the -// `failure_pops_pool` (see FastWordpieceTokenizerConfig in -// fast_wordpiece_tokenizer_model.fbs). As required by darts_clone_trie, the -// type of the encoded value should be 32-bit signed int, and the top bit is -// reserved to be always 0. -// -// Examples (given the existing constants; bits are numbered 0 to 31 from -// right/lower to left/upper; the top bit is reserved by darts_clone trie and is -// always 0): -// * Token "a", token id 0 -> The encoded value is 0x0: -// * bit 31: 0. -// * bit 30: 0, since token "a" is not a suffix token. -// * bits 29-8: 0, since the token id is 0. -// * bits 7-0: 0, since the encoded token length is 0 (see below comments). -// * Token "b", token id 1 -> The encoded value is 0x100: -// * bit 31: 0. -// * bit 30: 0, since token "b" is not a suffix token. -// * bits 29-8: 1, since the token id is 1. -// * bits 7-0: 0, since the encoded token length is 0 (see below comments). 
-// * Token "##b", token id 2 -> The encoded value is 0x40000200: -// * bit 31: 0. -// * bit 30: 1, since token "##b" is a suffix token. -// * bits 29-8: 2, since the token id is 2. -// * bits 7-0: 0, since the encoded token length is 0 (see below comments). -// * Token "bc", token id 3 -> The encoded value is 0x301: -// * bit 31: 0. -// * bit 30: 0, since token "bc" is not a suffix token. -// * bits 29-8: 3, since the token id is 3. -// * bits 7-0: 1, since the encoded token length is 1 (see below comments). -// * Token "##bcd", token id 5 -> The encoded value is 0x40000502: -// * bit 31: 0. -// * bit 30: 1, since token "##bcd" is a suffix token. -// * bits 29-8: 5, since the token id is 5. -// * bits 7-0: 2, since the encoded token length is 2 (see below comments). -// -// One special case is that when the suffix indicator is the empty string "". In -// this case, `is_suffix_token` is false for all tokens. -// -// Another special case is that when the suffix indicator string happens to be a -// token in the vocabulary. When encoding such a token like "##", by design, -// `is_suffix_token` is false, and the encoded token length is the full length -// of the suffix indicator string. -// -//////////////////////////////////////////////////////////////////////////////// - -// The (right-to-left 0-based) bit to encode whether the token is a suffix -// token. -static constexpr uint32_t kBitToIndicateSuffixToken = 30; - -// The number of low bits to encode the vocab token length into a compact -// representation. Technically, we encode the length of the token without the -// suffix indicator (if any) minus 1. Examples: -// * Token "a" -> we encode 1-1 = 0. -// * Token "abc" -> we encode 3-1 = 0. -// * Token "##abc" -> we encode 2, as before (we ignore the suffix indicator). -static constexpr uint32_t kBitsToEncodeVocabTokenLength = 8; - -// The bit mask to get the vocab token length from the compact representation. 
-static constexpr uint32_t kMaskToEncodeVocabTokenLength = - (1 << kBitsToEncodeVocabTokenLength) - 1; - -// Max vocab token length supported (given `kBitsToEncodeVocabTokenLength`). -static constexpr uint32_t kMaxVocabTokenLengthInUTF8Bytes = - (1 << kBitsToEncodeVocabTokenLength); - -// The maximum vocab size supported by our 32-bit encoding. Using right-to-left -// 0-based numbering, Bit 31 is reserved by darts_clone trie. Bit 30 indicates -// whether the token is a suffix token. The low `kBitsToEncodeVocabTokenLength` -// bits encode the token length. Given `kBitsToEncodeVocabTokenLength=8`, this -// leaves 32-1-1-8=22 bits for token ids, i.e., a max vocab size of 2^22 = 4M. -static constexpr uint32_t kMaxSupportedVocabSize = - (1 << (32 - 1 - 1 - kBitsToEncodeVocabTokenLength)); - -// The bit mask to get the vocab token id from the compact representation. -static constexpr uint32_t kMaskToEncodeVocabTokenId = - ((1 << kBitToIndicateSuffixToken) - 1) ^ kMaskToEncodeVocabTokenLength; - -//////////////////////////////////////////////////////////////////////////////// -// Helpers for encoding / decoding tokens. -//////////////////////////////////////////////////////////////////////////////// - -// Encodes a token into the encoded value. `token_length` is without the suffix -// indicator. The result is always a non-negative integer. Only used in building -// the model (in flatbuffer), not in doing WordPiece tokenization. -inline absl::StatusOr EncodeToken(int token_id, int token_length, - bool is_suffix_token) { - const int encoded_value = (is_suffix_token << kBitToIndicateSuffixToken) | - (token_id << kBitsToEncodeVocabTokenLength) | - (token_length - 1); - if (encoded_value < 0) { - return absl::FailedPreconditionError(absl::StrCat( - "EncodeToken() must return a non-negative value! 
Found encoded value: ", - encoded_value, " for input token id: ", token_id, ", token_length: ", - token_length, ", is_suffix_token: ", is_suffix_token)); - } - return encoded_value; -} - -// Gets whether it is a suffix token from the encoded value. -inline bool IsSuffixToken(int token_encoded_value) { - return static_cast(token_encoded_value >> kBitToIndicateSuffixToken); -} - -// Gets the token id from the encoded value. -inline int GetTokenId(int token_encoded_value) { - return (token_encoded_value & kMaskToEncodeVocabTokenId) >> - kBitsToEncodeVocabTokenLength; -} - -// Gets the token length (without the suffix indicator) from the encoded value. -inline int GetTokenLength(int token_encoded_value) { - return (token_encoded_value & kMaskToEncodeVocabTokenLength) + 1; -} - -//////////////////////////////////////////////////////////////////////////////// -// Constants for encoding failure pop lists. -// -// We put all failure pop lists into a common pool. The constants below define -// the compact representation that encodes (1) the offset, and (2) the length -// (minus 1) for a failure pop list in the common pool. -// -// Examples (given the existing constants; bits are numbered 0 to 31 from -// right/lower to left/upper): -// * failure pop list A, whose offset is 0 and length is 1 -> The encoded value -// is 0x0: -// * bits 31-8: 0, since the offset is 0. -// * bits 7-0: 0, since the encoded length is 0 (=1-1). -// * failure pop list B, whose offset is 0 and length is 3 -> The encoded value -// is 0x2: -// * bits 31-8: 0, since the offset is 0. -// * bits 7-0: 2, since the encoded length is 2 (=3-1). -// * failure pop list C, whose offset is 11 and the length is 10 -> The encoded -// value is 0xB09: -// * bits 31-8: 0xB, since the offset is 11. -// * bits 7-0: 9, since the encoded length is 9 (=10-1). 
-//////////////////////////////////////////////////////////////////////////////// - -// The number of low bits used to encode the length of failure pops minus 1 in -// the compact representation. This value should be less than or equal to -// `kBitsToEncodeVocabTokenLength`, since the size of failure pops is bounded by -// the maximum token length in the vocabulary. -static constexpr uint32_t kBitsToEncodeFailurePopsListSize = - kBitsToEncodeVocabTokenLength; - -// The bit mask to get the length of the failure pop list (without any suffix -// indicator, and minus 1) from the compact representation. -static constexpr uint32_t kMaskToEncodeFailurePopsListSize = - (1 << kBitsToEncodeFailurePopsListSize) - 1; - -// Max length of the failure pop list supported (given -// `kBitsToEncodeFailurePopsListSize`). -static constexpr uint32_t kMaxFailurePopsListSize = - (1 << kBitsToEncodeFailurePopsListSize); - -// The maximum valid offset in the failure pool, excluding the largest one -// (i.e., 0xFF...F), which is reserved to denote a null failure pop list (see -// `kNullFailurePopsList`). -static constexpr uint32_t kMaxSupportedFailurePoolOffset = - (1 << (32 - kBitsToEncodeFailurePopsListSize)) - 1 - 1; - -// Represents the null failure pops list, because 0xFF...F is not a valid of -// offset (see `kMaxSupportedFailurePoolOffset`). -static constexpr uint32_t kNullFailurePopsList = - std::numeric_limits::max(); - -//////////////////////////////////////////////////////////////////////////////// -// Helpers for encoding / decoding failure pop lists -//////////////////////////////////////////////////////////////////////////////// - -// Encodes the offset (in the failure pop pool) and the length of a failure pop -// list into an integer for a compact representation. 
-inline uint32_t EncodeFailurePopList(int offset, int length) { - return (offset << kBitsToEncodeFailurePopsListSize) | (length - 1); -} - -// Decodes the offset (in the failure pop pool) and the length of a failure pop -// list from the compact representation (an integer). -inline void GetFailurePopsOffsetAndLength(uint32_t offset_and_length, - int& out_offset, int& out_length) { - out_offset = offset_and_length >> kBitsToEncodeFailurePopsListSize; - out_length = (offset_and_length & kMaskToEncodeFailurePopsListSize) + 1; -} - -//////////////////////////////////////////////////////////////////////////////// -// Constants related to the Trie structure. -//////////////////////////////////////////////////////////////////////////////// - -// Represents the null node id. Different from any normal node. -static constexpr uint32_t kNullNode = std::numeric_limits::max(); - -// The maximum trie size supported. Because std::numeric_limits::max() -// (i.e., 0xFFFFFFFF) is reserved to represent the null node, the total trie -// size needs to be smaller or equal to 0xFFFFFFFF. -static constexpr uint32_t kMaxSupportedTrieSize = - std::numeric_limits::max(); - -//////////////////////////////////////////////////////////////////////////////// -// Helpers for analyzing Unicode characters. -//////////////////////////////////////////////////////////////////////////////// -inline bool IsPunctuationOrChineseChar(UChar32 char_value) { - uint32_t cp = static_cast(char_value); - // Chinese characters that are treated as punctuation in Bert. - if ((cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) || - (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) || - (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) || - (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F)) { - return true; - } - // Some special chars e.g. ">", "$" that are not covered by the u_ispunct are - // considered as punctuation chars. 
- if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || - (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) { - return true; - } - return u_ispunct(char_value); -} -} // namespace fast_wordpiece_tokenizer_utils -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_utils.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_UTILS_H_ diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils_test.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils_test.cc deleted file mode 100644 index a36900542..000000000 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils_test.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_utils.h" - -#include -#include - -namespace tensorflow { -namespace text { -namespace fast_wordpiece_tokenizer_utils { -namespace { - -// Testing spec struct for token encoding / decoding. -struct TokenSpec { - friend std::ostream& operator<<(std::ostream& os, const TokenSpec& s) { - return os << "token_id:" << s.token_id << ", " - << "token_length:" << s.token_length << ", " - << "is_suffix_token:" << s.is_suffix_token << std::endl; - } - - int token_id; - int token_length; - bool is_suffix_token; -}; - -// Parameterized tests specs for token encoding / decoding. 
-const std::vector& GetTokenSpecs() { - static const std::vector& kSpecs = *new std::vector{ - // Test 0. - { - .token_id = 0, - .token_length = 1, - .is_suffix_token = false, - }, - // Test 1. - { - .token_id = 1, - .token_length = 1, - .is_suffix_token = false, - }, - // Test 2. - { - .token_id = 2, - .token_length = 1, - .is_suffix_token = true, - }, - // Test 3. - { - .token_id = 3, - .token_length = 10, - .is_suffix_token = false, - }, - // Test 4. - { - .token_id = 4, - .token_length = 10, - .is_suffix_token = true, - }, - // Test 5. - { - .token_id = kMaxSupportedVocabSize - 1, - .token_length = kMaxVocabTokenLengthInUTF8Bytes, - .is_suffix_token = true, - }, - }; - return kSpecs; -} - -using TokenEncodingDecodingTest = testing::TestWithParam; - -TEST_P(TokenEncodingDecodingTest, GeneralTest) { - const TokenSpec& spec = GetParam(); - ASSERT_OK_AND_ASSIGN( - auto encoded_value, - EncodeToken(spec.token_id, spec.token_length, spec.is_suffix_token)); - EXPECT_THAT(GetTokenId(encoded_value), spec.token_id); - EXPECT_THAT(GetTokenLength(encoded_value), spec.token_length); - EXPECT_THAT(IsSuffixToken(encoded_value), spec.is_suffix_token); -} - -INSTANTIATE_TEST_SUITE_P(TestTokenEncodingDecoding, TokenEncodingDecodingTest, - testing::ValuesIn(GetTokenSpecs())); - -struct FailurePopListSpec { - friend std::ostream& operator<<(std::ostream& os, - const FailurePopListSpec& s) { - return os << "offset:" << s.offset << ", " - << "length:" << s.length << std::endl; - } - - int offset; - int length; -}; - -// Parameterized tests specs for failure pop list encoding and decoding. -const std::vector& GetFailurePopListSpecs() { - static const std::vector& kSpecs = - *new std::vector{ - // Test 0. - { - .offset = 0, - .length = 1, - }, - // Test 1. - { - .offset = 0, - .length = 3, - }, - // Test 2. - { - .offset = 11, - .length = 10, - }, - // Test 3. 
- { - .offset = kMaxSupportedFailurePoolOffset, - .length = kMaxFailurePopsListSize, - }, - }; - return kSpecs; -} - -using FailurePopListEncodingDecodingTest = - testing::TestWithParam; - -TEST_P(FailurePopListEncodingDecodingTest, GeneralTest) { - const FailurePopListSpec& spec = GetParam(); - auto offset_and_length = EncodeFailurePopList(spec.offset, spec.length); - int offset, length; - GetFailurePopsOffsetAndLength(offset_and_length, offset, length); - EXPECT_THAT(offset, spec.offset); - EXPECT_THAT(length, spec.length); -} - -INSTANTIATE_TEST_SUITE_P(TestFailurePopListEncodingDecoding, - FailurePopListEncodingDecodingTest, - testing::ValuesIn(GetFailurePopListSpecs())); - -} // namespace -} // namespace fast_wordpiece_tokenizer_utils -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/log_greedy_constrained_sequence_kernel_test.cc b/tensorflow_text/core/kernels/log_greedy_constrained_sequence_kernel_test.cc deleted file mode 100644 index 6d9d89054..000000000 --- a/tensorflow_text/core/kernels/log_greedy_constrained_sequence_kernel_test.cc +++ /dev/null @@ -1,799 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { - -using tensorflow::DT_INT32; -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::MatrixEq; -using tensorflow::text_kernels_test_util::VectorEq; - -class LogGreedyConstrainedSequenceTest : public tensorflow::OpsTestBase { - public: - void SetUpOpWithDefaults() { - // Prepare graph. - TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", false) - .Attr("use_log_space", true) - .Attr("use_start_and_end_states", true) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -// TODO(b/122968457): There are a bunch of tests that only validate !ok instead -// of looking for specific error messages; fix that. - -// This test examines evaluations with only a permissions matrix. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty weights matrix not of rank 2. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. 
- // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a 2D score matrix (implicit batch 1). -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithSingleBatchItem) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({1, 4}), // - { - 10.0, 12.0, 13.0, 4.0, // - }); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({1}), {1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // Validate the output. - std::vector expected_transitions({1}); - std::vector expected_offsets({0, 1}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines int64 input type and int32 output type. -TEST_F(LogGreedyConstrainedSequenceTest, int64inint32out) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. 
- AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - // Validate the output. - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op can take a sequence length of type {{X},{Y},{Z}} -// (with an outer batch dimension). -TEST_F(LogGreedyConstrainedSequenceTest, TwoDimensionalSequenceLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3, 1}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions that are forbidden by the permission -// matrix (final->null) are not taken. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeightsConstrainedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok; the next - // highest is 1, but 1->OUT is not OK; the next highest is 0, which is OK. - // The second sequence's highest score is 3, OUT->3 is OK and 3->OUT is OK. - // The third sequence's highest score is 0, OUT->0 is OK and 0->OUT is OK. - // Validate the output. 
- std::vector expected_transitions({0, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with only a weight matrix. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {10.1, 2.5, 7.5, 5.0} (max is 0) - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2) - // 3: {100.1, 24.5, 3.5, 5.0} (max is 0) - // Validate the output. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty not rank 2 permissions matrix. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. 
- AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {10.1, 2.5, 7.5, 5.0} (max is 0) - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2) - // 3: {100.1, 24.5, 3.5, 5.0} (max is 0) - // Validate the output. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are scored with the probability -// of ending the sequence on the transition (x->final->null). -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsWeightedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 0.1, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row and the last column in the - // score tensor, so the real scores are: - // 1: {10.1, 2.5, 7.5, 4.1} (max is 0) - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2) - // 3: {100.1, 24.5, 3.5, 5.0} (max is 0) - // Validate the output. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with both weight and permission matrices. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 7.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, false, // FROM 2 - true, true, true, true, true, // FROM 3 - false, true, true, true, false, // FROM 'OUT' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 1.0, // - 0.5, 0.5, 0.5, 0.5, 0.1, // - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row and the last column in the - // score tensor, so the real scores are: - // 1: {7.1, 2.5, 7.5, 4.1} (max is 3, but 2->NUL/NUL->0 is not OK, so 3.) - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2, but 2->NUL is not OK, so 1.) - // 3: {100.1, 24.5, 3.5, 5.0} (max is 0, but NUL->0 is not OK, so 1.) - // Validate the output. - std::vector expected_transitions({3, 1, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesMultipleTransitionsWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - false, true, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 1.0, 0.5, 1.0, // 2 - 0.5, 0.5, 0.5, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {10.1, 2.5, 7.5, 5.0} (max is 2). OUT->2 is OK. - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2). OUT->2 is OK. - // 3: {100.1, 11.5, 1.5, 11.0} (max is 0). OUT->0 is not OK, so go with 1. - // STEP 2: - // 1: In state '2', so use row 2 in the weight tensor. - // Weights are {11.5, 11.5, 12.0, 11.5}; 2->2 is OK and 2->OUT is OK; use 2. - // 2: In state '2', so use row 2 in the weight tensor. - // Weights are {10.5, 15.5, 2.0, 13.0}; 2->3 is not OK and 2->1 is not OK, so - // 0. 3: In state 0, so use row 0 in the weight tensor. Weights are - // {1.5, 11.5, 1.5, 11}; 0->1 is OK but 1->OUT is not, so 3. - - std::vector expected_transitions({2, 2, 2, 0, 1, 3}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesMultipleTransitionsWithVaryingLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // Batch 0, step 0 - 10.0, 10.0, 10.0, 10.0, // Batch 0, step 1 - 1.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - 100.0, 24.0, 3.0, 4.0, // Batch 2, step 0 - 1.0, 11.0, 1.0, 10.0, // Batch 2, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 1, 2}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, false, true, false, true, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - false, true, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.5, 0.5, 0.5, 0.5, 1.0, // 0 - 0.5, 0.5, 0.5, 0.5, 1.0, // 1 - 0.5, 0.5, 1.0, 0.5, 1.0, // 2 - 0.5, 0.5, 0.5, 0.5, 1.0, // 3 - 0.1, 0.5, 0.5, 1.0, 1.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {10.1, 2.5, 7.5, 5.0} (max is 2). OUT->2 is OK. - // 2: {1.1, 9.5, 11.5, 6.0} (max is 2). OUT->2 and 2->OUT are OK. - // 3: {100.1, 11.5, 1.5, 11.0} (max is 0). OUT->0 is not OK, so go with 1. - // STEP 2: - // 1: In state '2', so use row 2 in the weight tensor. - // Weights are {11.5, 11.5, 12.0, 11.5}; 2->2 is OK and 2->OUT is OK; use 2. - // 2: End of sequence. - // 3: In state 0, so use row 0 in the weight tensor. - // Weights are {1.5, 11.5, 1.5, 11}; 0->1 is OK but 1->OUT is not, so 3. - - std::vector expected_transitions({2, 2, 2, 1, 3}); - std::vector expected_offsets({0, 2, 3, 5}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a fully negative input set. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithNegativeInputs) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - -10.0, -12.0, -13.0, -4.0, // - -1.0, -12.0, -13.0, -14.0, // - -15.0, -2.0, -3.0, -14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, true, true, true, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - std::vector expected_transitions({3, 0, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an all-zero weight matrix. -TEST_F(LogGreedyConstrainedSequenceTest, - ComputesSingleTransitionWithZeroedWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - 100.0, 24.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), { - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, - }); - - TF_ASSERT_OK(RunOpKernel()); - - // Because all weights are zero, the max values should be the max of the - // scores. - std::vector expected_transitions({0, 2, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -TEST_F(LogGreedyConstrainedSequenceTest, - ImpossibleSequencesResultInNegativeOnesIfAttrIsSet) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - false, false, false, false, false, // FROM 0 - false, false, false, false, false, // FROM 1 - false, false, false, false, false, // FROM 2 - false, false, false, false, false, // FROM 3 - false, false, false, false, false, // FROM 'OUT' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // Validate the output. - - std::vector expected_transitions({-1, -1, -1, -1, -1, -1}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op will throw an error if there are too few scores to -// finalize all the sequences. -TEST_F(LogGreedyConstrainedSequenceTest, ErrorsIfGivenInsufficientScores) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 2, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/log_viterbi_constrained_sequence_kernel_test.cc b/tensorflow_text/core/kernels/log_viterbi_constrained_sequence_kernel_test.cc deleted file mode 100644 index 7e444a496..000000000 --- a/tensorflow_text/core/kernels/log_viterbi_constrained_sequence_kernel_test.cc +++ /dev/null @@ -1,815 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include -#include -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { - -using tensorflow::DT_INT32; -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::MatrixEq; -using tensorflow::text_kernels_test_util::VectorEq; - - -// TODO(b/122968457): There are a bunch of tests that only validate !ok instead -// of looking for specific error messages; fix that. - -class LogViterbiConstrainedSequenceTest : public tensorflow::OpsTestBase { - public: - void SetUpOpWithDefaults() { - // Prepare graph. - TF_ASSERT_OK(NodeDefBuilder("tested_op", "ConstrainedSequence") - .Attr("Tin", DT_INT32) - .Attr("use_viterbi", true) - .Attr("use_log_space", true) - .Attr("use_start_and_end_states", true) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -// This test examines evaluations with only a permissions matrix. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeights) { - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. 
- AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty weights matrix not of rank 2. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyWeights) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. 
- // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a 2D score matrix (implicit batch 1). -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithSingleBatchItem) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({1, 4}), // - { - 10.0, 12.0, 13.0, 4.0, // - }); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({1}), {1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // Validate the output. - std::vector expected_transitions({1}); - std::vector expected_offsets({0, 1}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines int64 input type and int32 output type. -TEST_F(LogViterbiConstrainedSequenceTest, int64inint32out) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. 
- AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - // Validate the output. - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op can take a sequence length of type {{X},{Y},{Z}} -// (with an outer batch dimension). -TEST_F(LogViterbiConstrainedSequenceTest, TwoDimensionalSequenceLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3, 1}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok, so it's 1. - // The second sequence's highest score is 3, which is ok. - // The third sequence's highest score is 0, which is ok. - - // Validate the output. - std::vector expected_transitions({1, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions that are forbidden by the permission -// matrix (final->null) are not taken. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoWeightsConstrainedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, false, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // The first sequence's highest score is 2, but OUT->2 is not ok; the next - // highest is 1, but 1->OUT is not OK; the next highest is 0, which is OK. - // The second sequence's highest score is 3, OUT->3 is OK and 3->OUT is OK. - // The third sequence's highest score is 0, OUT->0 is OK and 0->OUT is OK. - // Validate the output. 
- std::vector expected_transitions({0, 3, 0}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with only a weight matrix. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - -12.0, 3.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 10.0, 5.0, 3.0, 1.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {20.0, 7.0, 10.0, 5.0} (max is 0) - // 2: {11.0, 14.0, 14.0, 6.0} (max is 2, due to tiebreaker.) - // 3: {-2.0, 8.0, 6.0, 5.0} (max is 1) - // Validate the output. - std::vector expected_transitions({0, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with an empty not rank 2 permissions matrix. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNonMatrixEmptyPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - -12.0, 3.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), {0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 10.0, 5.0, 3.0, 1.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {20.0, 7.0, 10.0, 5.0} (max is 0) - // 2: {11.0, 14.0, 14.0, 6.0} (max is 2, due to tiebreaker.) - // 3: {-2.0, 8.0, 6.0, 5.0} (max is 1) - // Validate the output. - std::vector expected_transitions({0, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures that final transitions are scored with the probability -// of ending the sequence on the transition (x->final->null). -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNoPermissionsWeightedByEnd) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - -12.0, 3.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({0, 0}), {}); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.0, 0.0, 0.0, 0.0, -15.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 10.0, 5.0, 3.0, 1.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {5.0, 7.0, 10.0, 5.0} (max is 2 - state 1->null adds -15.) - // 2: {11.0, 14.0, 14.0, 6.0} (max is 2, due to tiebreaker.) - // 3: {-2.0, 8.0, 6.0, 5.0} (max is 1) - // Validate the output. - std::vector expected_transitions({2, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with both weight and permission matrices. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 2.0, 7.0, 4.0, // - 1.0, 9.0, 11.0, 5.0, // - -12.0, 3.0, 3.0, 4.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 'OUTSIDE' - false, true, true, true, false, // FROM 'NULL' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), {0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 0.0, 0.0, 0.0, 0.0, 0.0, // - 10.0, 5.0, 3.0, 1.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // All scores should be summed with the last row in the weight tensor, so - // the 'real' scores are: - // 1: {20.0, 7.0, 10.0, 5.0} (max is 0, but NUL->0 is forbidden, so 2.) - // 2: {11.0, 14.0, 14.0, 6.0} (max is 2, due to tiebreaker.) - // 3: {-2.0, 8.0, 6.0, 5.0} (max is 1) - // Validate the output. - std::vector expected_transitions({2, 2, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesMultipleTransitionsWithWeightsAndPermissions) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({2, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 13.0, 12.0, 11.0, 10.0, // Batch 0, step 1 - 7.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({2}), {2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, false, true, false, false, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({5, 5}), - {-1.0, 1.0, -2.0, 2.0, 0.0, // 0 - 3.0, -3.0, 4.0, -4.0, 0.0, // 1 - 5.0, -5.0, 6.0, -6.0, 0.0, // 2 - -7.0, 7.0, -8.0, 8.0, 0.0, // 3 - 0.0, 1.0, 2.0, 3.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be summed with the last row in the weight tensor, so the - // 'real' scores are: - // B0: { 10.0, [NOTOK], 9.0, 7.0} - // B1: { 7.0, [NOTOK], 13.0, 8.0} - // - // STEP 2: - // (Forbidden transitions are marked with '*' and X stands for the lowest - // possible score.) - // - // BATCH 0: - // Raw scores are: {13.0, 12.0, 11.0, 10.0} - // - // Final state 0: (13.0) Weighted scores are {12.0, 16.0, 18.0, 6.0} - // New totals are {22, X, 27, 18} [max 27 from 2] - // - // Final state 1: (12.0) Weighted scores are {13.0, 9.0, X, 19.0}, - // New totals are {23, X, X, 26} [max 26 from 3] - // - // Final state 2: (11.0) Weighted scores are {9, 15, 21, 3}, - // New totals are {19, X, 30, 10} [max 30 from 2] - // - // Final state 3: (10.0) Weighted scores are {12, 6, X, 18}, - // New totals are {19, X, X, 25} [max 25 from 3] - // - // Top scores are [27, 26, 30, 25] from [2, 3, 2, 3]. - // 2->OUT is X, so final scores are [27, 26, X, 25] for a - // final state of [0] with a sequence of [2->0]. - // - // - // BATCH 1: - // Previous scores are {7, X, 13, 8} - // Raw scores are {10, 15, 1, 12} - // - // Final state 0: Weighted score is {9, 18, 15, 3} - // New totals are {16, X, 28, 11} [max 28 from 2] - // - // Final state 1: Weighted score is {16, 12, 10, 22} - // New totals are {23, X, X*, 30} [max 30 from 3] - // - // Final state 2: Weighted score is {-1, 5, 7, -7} - // New totals are {6, X, 20, 1} [max 20 from 2] - // - // Final state 3: Weighted score is {14, 8, 6, 20} - // New totals are {21, X, X*, 28} [max 28 from 3] - // - // Top scores are [28, 30, 20, 28] from [2, 3, 2, 3]. - // 2->OUT is not valid, so final scores are [28, 30, X*, 28] for a - // final state of [1] with a sequence of [3->1]. 
- // - - std::vector expected_transitions({2, 0, 3, 1}); - std::vector expected_offsets({0, 2, 4}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines multiple evaluations with both weight and permission -// matrices. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesMultipleTransitionsWithVaryingLengths) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({2, 2, 4}), // - {{ - 10.0, 12.0, 7.0, 4.0, // Batch 0, step 0 - 0.0, 0.0, 0.0, 0.0, // PAD - 7.0, 9.0, 11.0, 5.0, // Batch 1, step 0 - 10.0, 15.0, 1.0, 12.0, // Batch 1, step 1 - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({2}), {1, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO NUL - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, false, true, false, false, // FROM 2 - true, true, true, true, true, // FROM 3 (OUT) - true, false, true, true, true, // FROM 'NULL' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({5, 5}), - {-1.0, 1.0, -2.0, 2.0, 0.0, // 0 - 3.0, -3.0, 4.0, -4.0, 0.0, // 1 - 5.0, -5.0, 6.0, -6.0, 0.0, // 2 - -7.0, 7.0, -8.0, 8.0, 0.0, // 3 - 0.0, 1.0, 2.0, 3.0, 0.0}); - - TF_ASSERT_OK(RunOpKernel()); - - // STEP 1: - // All scores should be summed with the last row in the weight tensor, so the - // 'real' scores are: - // B0: { 10.0, [NOTOK], 9.0, 7.0} - // B1: { 7.0, [NOTOK], 13.0, 8.0} - // - // STEP 2: - // (Forbidden transitions are marked with '*' and X stands for the lowest - // possible score.) - // - // BATCH 0: - // Batch 0 is complete. 
- // - // BATCH 1: - // Previous scores are {7, X, 13, 8} - // Raw scores are {10, 15, 1, 12} - // - // Final state 0: Weighted score is {9, 18, 15, 3} - // New totals are {16, X, 28, 11} [max 28 from 2] - // - // Final state 1: Weighted score is {16, 12, 10, 22} - // New totals are {23, X, X*, 30} [max 30 from 3] - // - // Final state 2: Weighted score is {-1, 5, 7, -7} - // New totals are {6, X, 20, 1} [max 20 from 2] - // - // Final state 3: Weighted score is {14, 8, 6, 20} - // New totals are {21, X, X*, 28} [max 28 from 3] - // - // Top scores are [28, 30, 20, 28] from [2, 3, 2, 3]. - // 2->OUT is not valid, so final scores are [28, 30, X*, 28] for a - // final state of [1] with a sequence of [3->1]. - // - - std::vector expected_transitions({0, 3, 1}); - std::vector expected_offsets({0, 1, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test examines evaluations with a fully negative input set. -TEST_F(LogViterbiConstrainedSequenceTest, - ComputesSingleTransitionWithNegativeInputs) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - -10.0, -12.0, -13.0, -4.0, // - -1.0, -12.0, -13.0, -14.0, // - -15.0, -2.0, -3.0, -14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 1, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, true, true, true, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. 
- AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - std::vector expected_transitions({3, 0, 1}); - std::vector expected_offsets({0, 1, 2, 3}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -TEST_F(LogViterbiConstrainedSequenceTest, - ImpossibleSequencesResultInNegativeOnesIfAttrIsSet) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. - AddInputFromArray(TensorShape({3, 2, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {2, 2, 2}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - false, false, false, false, false, // FROM 0 - false, false, false, false, false, // FROM 1 - false, false, false, false, false, // FROM 2 - false, false, false, false, false, // FROM 3 - false, false, false, false, false, // FROM 'OUT' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - TF_ASSERT_OK(RunOpKernel()); - - // Validate the output. - - std::vector expected_transitions({-1, -1, -1, -1, -1, -1}); - std::vector expected_offsets({0, 2, 4, 6}); - - // Validate the output. - EXPECT_THAT(*GetOutput(0), VectorEq(expected_transitions)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_offsets)); -} - -// This test ensures the op will throw an error if there are too few scores to -// finalize all the sequences. -TEST_F(LogViterbiConstrainedSequenceTest, ErrorsIfGivenInsufficientScores) { - // Prepare graph. - SetUpOpWithDefaults(); - - // Add the scores input. 
- AddInputFromArray(TensorShape({3, 1, 4}), // - {{ - 10.0, 12.0, 13.0, 4.0, // - 1.0, 12.0, 13.0, 14.0, // - 15.0, 2.0, 3.0, 14.0, // - }}); - - // Add the sequence_lengths input. - AddInputFromArray(TensorShape({3}), {1, 2, 1}); - - // Add the allowed_transitions input. - AddInputFromArray(TensorShape({5, 5}), - { - // TO 0 TO 1 TO 2 TO 3 TO OUT - true, true, true, true, true, // FROM 0 - true, true, true, true, true, // FROM 1 - true, true, true, true, true, // FROM 2 - true, true, true, true, true, // FROM 3 - true, true, false, true, false, // FROM 'OUTSIDE' - }); - - // Add the transition_weights input. - AddInputFromArray(TensorShape({0, 0}), {}); - - auto result = RunOpKernel(); - EXPECT_FALSE(result.ok()); -} - -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/mst_op_kernels.cc b/tensorflow_text/core/kernels/mst_op_kernels.cc deleted file mode 100644 index 01eb5954b..000000000 --- a/tensorflow_text/core/kernels/mst_op_kernels.cc +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/util/work_sharder.h" -#include "tensorflow_text/core/kernels/mst_solver.h" - -namespace tensorflow { -namespace text { - -// Op kernel implementation that wraps the |MstSolver|. -template -class MaxSpanningTreeOpKernel : public tensorflow::OpKernel { - public: - explicit MaxSpanningTreeOpKernel(tensorflow::OpKernelConstruction *context) - : tensorflow::OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("forest", &forest_)); - } - - void Compute(tensorflow::OpKernelContext *context) override { - const tensorflow::Tensor &num_nodes_tensor = context->input(0); - const tensorflow::Tensor &scores_tensor = context->input(1); - - // Check ranks. - OP_REQUIRES(context, num_nodes_tensor.dims() == 1, - tensorflow::errors::InvalidArgument( - "num_nodes must be a vector, got shape ", - num_nodes_tensor.shape().DebugString())); - OP_REQUIRES(context, scores_tensor.dims() == 3, - tensorflow::errors::InvalidArgument( - "scores must be rank 3, got shape ", - scores_tensor.shape().DebugString())); - - // Batch size and input dimension (B and M in the op docstring). - const int64 batch_size = scores_tensor.shape().dim_size(0); - const int64 input_dim = scores_tensor.shape().dim_size(1); - - // Check shapes. 
- const tensorflow::TensorShape shape_b({batch_size}); - const tensorflow::TensorShape shape_bxm({batch_size, input_dim}); - const tensorflow::TensorShape shape_bxmxm( - {batch_size, input_dim, input_dim}); - OP_REQUIRES( - context, num_nodes_tensor.shape() == shape_b, - tensorflow::errors::InvalidArgument( - "num_nodes misshapen: got ", num_nodes_tensor.shape().DebugString(), - " but expected ", shape_b.DebugString())); - OP_REQUIRES( - context, scores_tensor.shape() == shape_bxmxm, - tensorflow::errors::InvalidArgument( - "scores misshapen: got ", scores_tensor.shape().DebugString(), - " but expected ", shape_bxmxm.DebugString())); - - // Create outputs. - tensorflow::Tensor *max_scores_tensor = nullptr; - tensorflow::Tensor *argmax_sources_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, shape_b, &max_scores_tensor)); - OP_REQUIRES_OK(context, context->allocate_output(1, shape_bxm, - &argmax_sources_tensor)); - - // Acquire shaped and typed references. - const BatchedSizes num_nodes_b = num_nodes_tensor.vec(); - const BatchedScores scores_bxmxm = scores_tensor.tensor(); - BatchedMaxima max_scores_b = max_scores_tensor->vec(); - BatchedSources argmax_sources_bxm = argmax_sources_tensor->matrix(); - - // Solve the batch of MST problems in parallel. Set a high cycles per unit - // to encourage finer sharding. 
- constexpr int64 kCyclesPerUnit = 1000 * 1000 * 1000; - std::vector statuses(batch_size); - context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( - batch_size, kCyclesPerUnit, [&](int64 begin, int64 end) { - for (int64 problem = begin; problem < end; ++problem) { - statuses[problem] = RunSolver(problem, num_nodes_b, scores_bxmxm, - max_scores_b, argmax_sources_bxm); - } - }); - for (const absl::Status &status : statuses) { - OP_REQUIRES_OK(context, status); - } - } - - private: - using BatchedSizes = typename tensorflow::TTypes::ConstVec; - using BatchedScores = typename tensorflow::TTypes::ConstTensor; - using BatchedMaxima = typename tensorflow::TTypes::Vec; - using BatchedSources = typename tensorflow::TTypes::Matrix; - - // Solves for the maximum spanning tree of the digraph defined by the values - // at index |problem| in |num_nodes_b| and |scores_bxmxm|. On success, sets - // the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|. - // On error, returns non-OK. - absl::Status RunSolver(int problem, BatchedSizes num_nodes_b, - BatchedScores scores_bxmxm, BatchedMaxima max_scores_b, - BatchedSources argmax_sources_bxm) const { - // Check digraph size overflow. - const int32 num_nodes = num_nodes_b(problem); - const int32 input_dim = argmax_sources_bxm.dimension(1); - if (num_nodes > input_dim) { - return tensorflow::errors::InvalidArgument( - "number of nodes in digraph ", problem, - " overflows input dimension: got ", num_nodes, - " but expected <= ", input_dim); - } - if (num_nodes >= std::numeric_limits::max()) { - return tensorflow::errors::InvalidArgument( - "number of nodes in digraph ", problem, " overflows index type: got ", - num_nodes, " but expected < ", std::numeric_limits::max()); - } - const Index num_nodes_index = static_cast(num_nodes); - - MstSolver solver; - TF_RETURN_IF_ERROR(solver.Init(forest_, num_nodes_index)); - - // Populate the solver with arcs and root selections. 
Note that non-finite - // scores are treated as nonexistent arcs or roots. - for (Index target = 0; target < num_nodes_index; ++target) { - for (Index source = 0; source < num_nodes_index; ++source) { - const Score score = scores_bxmxm(problem, target, source); - if (!std::isfinite(static_cast(score))) continue; - if (source == target) { // root - solver.AddRoot(target, score); - } else { // arc - solver.AddArc(source, target, score); - } - } - } - - std::vector argmax(num_nodes); - TF_RETURN_IF_ERROR(solver.Solve(&argmax)); - - // Output the tree and accumulate its score. - Score max_score = 0; - for (Index target = 0; target < num_nodes_index; ++target) { - const Index source = argmax[target]; - argmax_sources_bxm(problem, target) = source; - max_score += scores_bxmxm(problem, target, source); - } - max_scores_b(problem) = max_score; - - // Pad the source list with -1. - for (int32 i = num_nodes; i < input_dim; ++i) { - argmax_sources_bxm(problem, i) = -1; - } - - return absl::OkStatus(); - } - - private: - bool forest_ = false; -}; - -// Use Index=uint16, which allows digraphs containing up to 32,767 nodes. -REGISTER_KERNEL_BUILDER(Name("MaxSpanningTree") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint("T"), - MaxSpanningTreeOpKernel); -REGISTER_KERNEL_BUILDER(Name("MaxSpanningTree") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint("T"), - MaxSpanningTreeOpKernel); -REGISTER_KERNEL_BUILDER(Name("MaxSpanningTree") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint("T"), - MaxSpanningTreeOpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/mst_solver.h b/tensorflow_text/core/kernels/mst_solver.h index b75e964bf..dbcc545ff 100644 --- a/tensorflow_text/core/kernels/mst_solver.h +++ b/tensorflow_text/core/kernels/mst_solver.h @@ -12,596 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ -#include +#include "tensorflow/core/kernels/text/mst_solver.h" -#include -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow_text/core/kernels/disjoint_set_forest.h" - -namespace tensorflow { -namespace text { - -// Maximum spanning tree solver for directed graphs. Thread-compatible. -// -// The solver operates on a digraph of n nodes and m arcs and outputs a maximum -// spanning tree rooted at any node. Scores can be associated with arcs and -// root selections, and the score of a tree is the sum of the relevant arc and -// root-selection scores. -// -// The implementation is based on: -// -// go/tarjan-1977 google-only -// R.E. Tarjan. 1977. Finding Optimum Branchings. Networks 7(1), pp. 25-35. -// [In particular, see Section 4 "a modification for dense graphs"] -// -// which itself is an improvement of the Chu-Liu-Edmonds algorithm. Note also -// the correction in: -// -// go/camerini-1979 google-only -// P.M. Camerini, L. Fratta, F. Maffioli. 1979. A Note on Finding Optimum -// Branchings. Networks 9(4), pp. 309-312. -// -// The solver runs in O(n^2) time, which is optimal for dense digraphs but slow -// for sparse digraphs where O(m + n log n) can be achieved. The solver uses -// O(n^2) space to store the digraph, which is also optimal for dense digraphs. -// -// Although this algorithm has an inferior asymptotic runtime on sparse graphs, -// it avoids high-constant-overhead data structures like Fibonacci heaps, which -// are required in the asymptotically faster algorithms. Therefore, this solver -// may still be competitive on small sparse graphs. 
-// -// TODO(terrykoo): If we start running on large sparse graphs, implement the -// following, which runs in O(m + n log n): -// -// go/tarjan-1986 google-only -// H.N. Gabow, Z. Galil, T. Spencer, and R.E. Tarjan. 1986. Efficient -// algorithms for finding minimum spanning trees in undirected and directed -// graphs. Combinatorica, 6(2), pp. 109-122. -// -// Template args: -// Index: An unsigned integral type wide enough to hold 2n. -// Score: A signed arithmetic (integral or floating-point) type. -template -class MstSolver { - public: - static_assert(std::is_integral::value, "Index must be integral"); - static_assert(!std::is_signed::value, "Index must be unsigned"); - static_assert(std::is_arithmetic::value, "Score must be arithmetic"); - static_assert(std::is_signed::value, "Score must be signed"); - using IndexType = Index; - using ScoreType = Score; - - // Creates an empty solver. Call Init() before use. - MstSolver() = default; - - // Initializes this for a digraph with |num_nodes| nodes, or returns non-OK on - // error. Discards existing state; call AddArc() and AddRoot() to add arcs - // and root selections. If |forest| is true, then this solves for a maximum - // spanning forest (i.e., a set of disjoint trees that span the digraph). - absl::Status Init(bool forest, Index num_nodes); - - // Adds an arc from the |source| node to the |target| node with the |score|. - // The |source| and |target| must be distinct node indices in [0,n), and the - // |score| must be finite. Calling this multiple times on the same |source| - // and |target| overwrites the score instead of adding parallel arcs. - void AddArc(Index source, Index target, Score score); - - // As above, but adds a root selection for the |root| node with the |score|. - void AddRoot(Index root, Score score); - - // Returns the score of the arc from |source| to |target|, which must have - // been added by a previous call to AddArc(). 
- Score ArcScore(Index source, Index target) const; - - // Returns the score of selecting the |root|, which must have been added by a - // previous call to AddRoot(). - Score RootScore(Index root) const; - - // Populates |argmax| with the maximum directed spanning tree of the current - // digraph, or returns non-OK on error. The |argmax| array must contain at - // least n elements. On success, argmax[t] is the source of the arc directed - // into t, or t itself if t is a root. - // - // NB: If multiple spanning trees achieve the maximum score, |argmax| will be - // set to one of the maximal trees, but it is unspecified which one. - absl::Status Solve(absl::Span argmax); - - // Convience method - absl::Status Solve(std::vector *argmax) { - return Solve(absl::MakeSpan(argmax->data(), argmax->size())); - } - - private: - // Implementation notes: - // - // The solver does not operate on the "original" digraph as specified by the - // user, but a "transformed" digraph that differs as follows: - // - // * The transformed digraph adds an "artificial root" node at index 0 and - // offsets all original node indices by +1 to make room. For each root - // selection, the artificial root has one outbound arc directed into the - // candidate root that carries the root-selection score. The artificial - // root has no inbound arcs. - // - // * When solving for a spanning tree (i.e., when |forest_| is false), the - // outbound arcs of the artificial root are penalized to ensure that the - // artificial root has exactly one child. - // - // In the remainder of this file, all mentions of nodes, arcs, etc., refer to - // the transformed digraph unless otherwise specified. - // - // The algorithm is divided into two phases, the "contraction phase" and the - // "expansion phase". The contraction phase finds the arcs that make up the - // maximum spanning tree by applying a series of "contractions" which further - // modify the digraph. 
The expansion phase "expands" these modifications and - // recovers the maximum spanning tree in the original digraph. - // - // During the contraction phase, the algorithm selects the best inbound arc - // for each node. These arcs can form cycles, which are "contracted" by - // removing the cycle nodes and replacing them with a new contracted node. - // Since each contraction removes 2 or more cycle nodes and adds 1 contracted - // node, at most n-1 contractions will occur. (The digraph initially contains - // n+1 nodes, but one is the artificial root, which cannot form a cycle). - // - // When contracting a cycle, nodes are not explicitly removed and replaced. - // Instead, a contracted node is appended to the digraph and the cycle nodes - // are remapped to the contracted node, which implicitly removes and replaces - // the cycle. As a result, each contraction actually increases the size of - // the digraph, up to a maximum of 2n nodes. One advantage of adding and - // remapping nodes is that it is convenient to recover the argmax spanning - // tree during the expansion phase. - // - // Note that contractions can be nested, because the best inbound arc for a - // contracted node may itelf form a cycle. During the expansion phase, the - // algorithm picks a root of the hierarchy of contracted nodes, breaks the - // cycle it represents, and repeats until all cycles are broken. - - // Constants, as enums to avoid the need for static variable definitions. - enum Constants : Index { - // An index reserved for "null" values. - kNullIndex = std::numeric_limits::max(), - }; - - // A possibly-nonexistent arc in the digraph. - struct Arc { - // Creates a nonexistent arc. - Arc() = default; - - // Returns true if this arc exists. - bool Exists() const { return target != 0; } - - // Returns true if this is a root-selection arc. - bool IsRoot() const { return source == 0; } - - // Returns a string representation of this arc. 
- std::string DebugString() const { - if (!Exists()) return "[null]"; - if (IsRoot()) { - return absl::StrCat("[*->", target, "=", score, "]"); - } - return absl::StrCat("[", source, "->", target, "=", score, "]"); - } - - // Score of this arc. - Score score; - - // Source of this arc in the initial digraph. - Index source; - - // Target of this arc in the initial digraph, or 0 if this is nonexistent. - Index target = 0; - }; - - // Returns the index, in |arcs_|, of the arc from |source| to |target|. The - // |source| must be one of the initial n+1 nodes. - size_t ArcIndex(size_t source, size_t target) const; - - // Penalizes the root arc scores to ensure that this finds a tree, or does - // nothing if |forest_| is true. Must be called before ContractionPhase(). - void MaybePenalizeRootScoresForTree(); - - // Returns the maximum inbound arc of the |node|, or null if there is none. - const Arc *MaximumInboundArc(Index node) const; - - // Merges the inbound arcs of the |cycle_node| into the inbound arcs of the - // |contracted_node|. Arcs are merged as follows: - // * If the source and target of the arc belong to the same strongly-connected - // component, it is ignored. - // * If exactly one of the nodes had an arc from some source, then on exit the - // |contracted_node| has that arc. - // * If both of the nodes had an arc from the same source, then on exit the - // |contracted_node| has the better-scoring arc. - // The |score_offset| is added to the arc scores of the |cycle_node| before - // they are merged into the |contracted_node|. - void MergeInboundArcs(Index cycle_node, Score score_offset, - Index contracted_node); - - // Contracts the cycle in |argmax_arcs_| that contains the |node|. - void ContractCycle(Index node); - - // Runs the contraction phase of the solver, or returns non-OK on error. This - // phase finds the best inbound arc for each node, contracting cycles as they - // are formed. 
Stops when every node has selected an inbound arc and there - // are no cycles. - absl::Status ContractionPhase(); - - // Runs the expansion phase of the solver, or returns non-OK on error. This - // phase expands each contracted node, breaks cycles, and populates |argmax| - // with the maximum spanning tree. - absl::Status ExpansionPhase(absl::Span argmax); - - // If true, solve for a spanning forest instead of a spanning tree. - bool forest_ = false; - - // The number of nodes in the original digraph; i.e., n. - Index num_original_nodes_ = 0; - - // The number of nodes in the initial digraph; i.e., n+1. - Index num_initial_nodes_ = 0; - - // The maximum number of possible nodes in the digraph; i.e., 2n. - Index num_possible_nodes_ = 0; - - // The number of nodes in the current digraph, which grows from n+1 to 2n. - Index num_current_nodes_ = 0; - - // Column-major |num_initial_nodes_| x |num_current_nodes_| matrix of arcs, - // where rows and columns correspond to source and target nodes. Columns are - // added as cycles are contracted into new nodes. - // - // TODO(terrykoo): It is possible to squeeze the nonexistent arcs out of each - // column and run the algorithm with each column being a sorted list (sorted - // by source node). This is in fact the suggested representation in Tarjan - // (1977). This won't improve the asymptotic runtime but still might improve - // speed in practice. I haven't done this because it adds complexity versus - // checking Arc::Exists() in a few loops. Try this out when we can benchmark - // this on real data. - std::vector arcs_; - - // Disjoint-set forests tracking the weakly-connected and strongly-connected - // components of the initial digraph, based on the arcs in |argmax_arcs_|. - // Weakly-connected components are used to detect cycles; strongly-connected - // components are used to detect self-loops. 
- DisjointSetForest weak_components_; - DisjointSetForest strong_components_; - - // A disjoint-set forest that maps each node to the top-most contracted node - // that contains it. Nodes that have not been contracted map to themselves. - // NB: This disjoint-set forest does not use union by rank so we can control - // the outcome of a set union. There will only be O(n) operations on this - // instance, so the increased O(log n) cost of each operation is acceptable. - DisjointSetForest contracted_nodes_; - - // An array that represents the history of cycle contractions, as follows: - // * If contracted_into_[t] is |kNullIndex|, then t is deleted. - // * If contracted_into_[t] is 0, then t is a "root" contracted node; i.e., t - // has not been contracted into another node. - // * Otherwise, contracted_into_[t] is the node into which t was contracted. - std::vector contracted_into_; - - // The maximum inbound arc for each node. The first element is null because - // the artificial root has no inbound arcs. - std::vector argmax_arcs_; - - // Workspace for ContractCycle(), which records the nodes and arcs in the - // cycle being contracted. - std::vector> cycle_; -}; - -// Implementation details below. - -template -absl::Status MstSolver::Init(bool forest, Index num_nodes) { - if (num_nodes <= 0) { - return tensorflow::errors::InvalidArgument("Non-positive number of nodes: ", - num_nodes); - } - - // Upcast to size_t to avoid overflow. - if (2 * static_cast(num_nodes) >= static_cast(kNullIndex)) { - return tensorflow::errors::InvalidArgument("Too many nodes: ", num_nodes); - } - - forest_ = forest; - num_original_nodes_ = num_nodes; - num_initial_nodes_ = num_original_nodes_ + 1; - num_possible_nodes_ = 2 * num_original_nodes_; - num_current_nodes_ = num_initial_nodes_; - - // Allocate the full n+1 x 2n matrix, but start with a n+1 x n+1 prefix. 
- const size_t num_initial_arcs = static_cast(num_initial_nodes_) * - static_cast(num_initial_nodes_); - const size_t num_possible_arcs = static_cast(num_initial_nodes_) * - static_cast(num_possible_nodes_); - arcs_.reserve(num_possible_arcs); - arcs_.assign(num_initial_arcs, {}); - - weak_components_.Init(num_initial_nodes_); - strong_components_.Init(num_initial_nodes_); - contracted_nodes_.Init(num_possible_nodes_); - contracted_into_.assign(num_possible_nodes_, 0); - argmax_arcs_.assign(num_possible_nodes_, nullptr); - - // This doesn't need to be cleared now; it will be cleared before use. - cycle_.reserve(num_original_nodes_); - - return absl::OkStatus(); -} - -template -void MstSolver::AddArc(Index source, Index target, Score score) { - DCHECK_NE(source, target); - DCHECK(std::isfinite(score)); - Arc &arc = arcs_[ArcIndex(source + 1, target + 1)]; - arc.score = score; - arc.source = source + 1; - arc.target = target + 1; -} - -template -void MstSolver::AddRoot(Index root, Score score) { - DCHECK(std::isfinite(score)); - Arc &arc = arcs_[ArcIndex(0, root + 1)]; - arc.score = score; - arc.source = 0; - arc.target = root + 1; -} - -template -Score MstSolver::ArcScore(Index source, Index target) const { - const Arc &arc = arcs_[ArcIndex(source + 1, target + 1)]; - DCHECK(arc.Exists()); - return arc.score; -} - -template -Score MstSolver::RootScore(Index root) const { - const Arc &arc = arcs_[ArcIndex(0, root + 1)]; - DCHECK(arc.Exists()); - return arc.score; -} - -template -absl::Status MstSolver::Solve(absl::Span argmax) { - MaybePenalizeRootScoresForTree(); - TF_RETURN_IF_ERROR(ContractionPhase()); - TF_RETURN_IF_ERROR(ExpansionPhase(argmax)); - return absl::OkStatus(); -} - -template -inline size_t MstSolver::ArcIndex(size_t source, - size_t target) const { - DCHECK_LT(source, num_initial_nodes_); - DCHECK_LT(target, num_current_nodes_); - return source + target * static_cast(num_initial_nodes_); -} - -template -void 
MstSolver::MaybePenalizeRootScoresForTree() { - if (forest_) return; - DCHECK_EQ(num_current_nodes_, num_initial_nodes_) - << "Root penalties must be applied before starting the algorithm."; - - // Find the minimum and maximum arc scores. These allow us to bound the range - // of possible tree scores. - Score max_score = std::numeric_limits::lowest(); - Score min_score = std::numeric_limits::max(); - for (const Arc &arc : arcs_) { - if (!arc.Exists()) continue; - max_score = std::max(max_score, arc.score); - min_score = std::min(min_score, arc.score); - } - - // Nothing to do, no existing arcs. - if (max_score < min_score) return; - - // A spanning tree or forest contains n arcs. The penalty below ensures that - // every structure with one root has a higher score than every structure with - // two roots, and so on. - const Score root_penalty = 1 + num_initial_nodes_ * (max_score - min_score); - for (Index root = 1; root < num_initial_nodes_; ++root) { - Arc &arc = arcs_[ArcIndex(0, root)]; - if (!arc.Exists()) continue; - arc.score -= root_penalty; - } -} - -template -const typename MstSolver::Arc * -MstSolver::MaximumInboundArc(Index node) const { - const Arc *__restrict arc = &arcs_[ArcIndex(0, node)]; - const Arc *arc_end = arc + num_initial_nodes_; - - Score max_score = std::numeric_limits::lowest(); - const Arc *argmax_arc = nullptr; - for (; arc < arc_end; ++arc) { - if (!arc->Exists()) continue; - const Score score = arc->score; - if (max_score <= score) { - max_score = score; - argmax_arc = arc; - } - } - return argmax_arc; -} - -template -void MstSolver::MergeInboundArcs(Index cycle_node, - Score score_offset, - Index contracted_node) { - const Arc *__restrict cycle_arc = &arcs_[ArcIndex(0, cycle_node)]; - const Arc *cycle_arc_end = cycle_arc + num_initial_nodes_; - Arc *__restrict contracted_arc = &arcs_[ArcIndex(0, contracted_node)]; - - for (; cycle_arc < cycle_arc_end; ++cycle_arc, ++contracted_arc) { - if (!cycle_arc->Exists()) continue; // nothing to 
merge - - // Skip self-loops; they are useless because they cannot be used to break - // the cycle represented by the |contracted_node|. - if (strong_components_.SameSet(cycle_arc->source, cycle_arc->target)) { - continue; - } - - // Merge the |cycle_arc| into the |contracted_arc|. - const Score cycle_score = cycle_arc->score + score_offset; - if (!contracted_arc->Exists() || contracted_arc->score < cycle_score) { - contracted_arc->score = cycle_score; - contracted_arc->source = cycle_arc->source; - contracted_arc->target = cycle_arc->target; - } - } -} - -template -void MstSolver::ContractCycle(Index node) { - // Append a new node for the contracted cycle. - const Index contracted_node = num_current_nodes_++; - DCHECK_LE(num_current_nodes_, num_possible_nodes_); - arcs_.resize(arcs_.size() + num_initial_nodes_); - - // We make two passes through the cycle. The first pass updates everything - // except the |arcs_|, and the second pass updates the |arcs_|. The |arcs_| - // must be updated in a second pass because MergeInboundArcs() requires that - // the |strong_components_| are updated with the newly-contracted cycle. - cycle_.clear(); - Index cycle_node = node; - do { - // Gather the nodes and arcs in |cycle_| for the second pass. - const Arc *cycle_arc = argmax_arcs_[cycle_node]; - DCHECK(!cycle_arc->IsRoot()) << cycle_arc->DebugString(); - cycle_.emplace_back(cycle_node, cycle_arc); - - // Mark the cycle nodes as members of a strongly-connected component. - strong_components_.Union(cycle_arc->source, cycle_arc->target); - - // Mark the cycle nodes as members of the new contracted node. Juggling is - // required because |contracted_nodes_| also determines the next cycle node. 
- const Index next_node = contracted_nodes_.FindRoot(cycle_arc->source); - contracted_nodes_.UnionOfRoots(cycle_node, contracted_node); - contracted_into_[cycle_node] = contracted_node; - cycle_node = next_node; - - // When the cycle repeats, |cycle_node| will be equal to |contracted_node|, - // not |node|, because the first iteration of this loop mapped |node| to - // |contracted_node| in |contracted_nodes_|. - } while (cycle_node != contracted_node); - - // Merge the inbound arcs of each cycle node into the |contracted_node|. - for (const auto &node_and_arc : cycle_) { - // Set the |score_offset| to the cost of breaking the cycle by replacing the - // arc currently directed into the |cycle_node|. - const Index cycle_node = node_and_arc.first; - const Score score_offset = -node_and_arc.second->score; - MergeInboundArcs(cycle_node, score_offset, contracted_node); - } -} - -template -absl::Status MstSolver::ContractionPhase() { - // Skip the artificial root since it has no inbound arcs. - for (Index target = 1; target < num_current_nodes_; ++target) { - // Find the maximum inbound arc for the current |target|, if any. - const Arc *arc = MaximumInboundArc(target); - if (arc == nullptr) { - return tensorflow::errors::FailedPrecondition("Infeasible digraph"); - } - argmax_arcs_[target] = arc; - - // The articifial root cannot be part of a cycle, so we do not need to check - // for cycles or even update its membership in the connected components. - if (arc->IsRoot()) continue; - - // Since every node has at most one selected inbound arc, cycles can be - // detected using weakly-connected components. - const Index source_component = weak_components_.FindRoot(arc->source); - const Index target_component = weak_components_.FindRoot(arc->target); - if (source_component == target_component) { - // Cycle detected; contract it into a new node. - ContractCycle(target); - } else { - // No cycles, just update the weakly-connected components. 
- weak_components_.UnionOfRoots(source_component, target_component); - } - } - - return absl::OkStatus(); -} - -template -absl::Status MstSolver::ExpansionPhase(absl::Span argmax) { - if (argmax.size() < num_original_nodes_) { - return tensorflow::errors::InvalidArgument( - "Argmax array too small: ", num_original_nodes_, - " elements required, but got ", argmax.size()); - } - - // Select and expand a root contracted node until no contracted nodes remain. - // Thanks to the (topological) order in which contracted nodes are appended, - // root contracted nodes are easily enumerated using a backward scan. After - // this loop, entries [1,n] of |argmax_arcs_| provide the arcs of the maximum - // spanning tree. - for (Index i = num_current_nodes_ - 1; i >= num_initial_nodes_; --i) { - if (contracted_into_[i] == kNullIndex) continue; // already deleted - const Index root = i; // if not deleted, must be a root due to toposorting - - // Copy the cycle-breaking arc to its specified target. - const Arc *arc = argmax_arcs_[root]; - argmax_arcs_[arc->target] = arc; - - // The |arc| not only breaks the cycle associated with the |root|, but also - // breaks every nested cycle between the |root| and the target of the |arc|. - // Delete the contracted nodes corresponding to all broken cycles. - Index node = contracted_into_[arc->target]; - while (node != kNullIndex && node != root) { - const Index parent = contracted_into_[node]; - contracted_into_[node] = kNullIndex; - node = parent; - } - } - - // Copy the spanning tree from |argmax_arcs_| to |argmax|. Also count roots - // for validation below. 
- Index num_roots = 0; - for (Index target = 0; target < num_original_nodes_; ++target) { - const Arc &arc = *argmax_arcs_[target + 1]; - DCHECK_EQ(arc.target, target + 1) << arc.DebugString(); - if (arc.IsRoot()) { - ++num_roots; - argmax[target] = target; - } else { - argmax[target] = arc.source - 1; - } - } - DCHECK_GE(num_roots, 1); - - // Even when |forest_| is false, |num_roots| can still be more than 1. While - // the root score penalty discourages structures with multiple root arcs, it - // is not a hard constraint. For example, if the original digraph contained - // one root selection per node and no other arcs, the solver would incorrectly - // produce an all-root structure in spite of the root score penalty. As this - // example illustrates, however, |num_roots| will be more than 1 if and only - // if the original digraph is infeasible for trees. - if (!forest_ && num_roots != 1) { - return tensorflow::errors::FailedPrecondition("Infeasible digraph"); - } - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_MST_SOLVER_H_ diff --git a/tensorflow_text/core/kernels/mst_solver_random_comparison_test.cc b/tensorflow_text/core/kernels/mst_solver_random_comparison_test.cc deleted file mode 100644 index 9896801b5..000000000 --- a/tensorflow_text/core/kernels/mst_solver_random_comparison_test.cc +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include - -#include -#include -#include "absl/flags/flag.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow_text/core/kernels/mst_solver.h" -#include "tensorflow_text/core/kernels/spanning_tree_iterator.h" - -ABSL_FLAG(int64_t, seed, 0, - "Seed for random comparison tests, or 0 for a weak random seed."); -ABSL_FLAG(int, num_trials, 3, "Number of trials for random comparison tests."); - -namespace tensorflow { -namespace text { - -using ::testing::Contains; - -// Returns the random seed, or 0 for a weak random seed. -int64 GetSeed() { return absl::GetFlag(FLAGS_seed); } - -// Returns the number of trials to run for each random comparison. -int64 GetNumTrials() { return absl::GetFlag(FLAGS_num_trials); } - -// Testing rig. Runs a comparison between a brute-force MST solver and the -// MstSolver<> on random digraphs. When the first test parameter is true, -// solves for forests instead of trees. The second test parameter defines the -// size of the test digraph. -class MstSolverRandomComparisonTest - : public ::testing::TestWithParam<::testing::tuple> { - protected: - // Use integer scores so score comparisons are exact. - using Solver = MstSolver; - - // An array providing a source node for each node. Roots are self-loops. - using SourceList = SpanningTreeIterator::SourceList; - - // A row-major n x n matrix whose i,j entry gives the score of the arc from i - // to j, and whose i,i entry gives the score of selecting i as a root. - using ScoreMatrix = std::vector; - - // Returns true if this should be a forest. - bool forest() const { return ::testing::get<0>(GetParam()); } - - // Returns the number of nodes for digraphs. - uint32 num_nodes() const { return ::testing::get<1>(GetParam()); } - - // Returns the score of the arcs in |sources| based on the |scores|. 
- int32 ScoreArcs(const ScoreMatrix &scores, const SourceList &sources) const { - CHECK_EQ(num_nodes() * num_nodes(), scores.size()); - int32 score = 0; - for (uint32 target = 0; target < num_nodes(); ++target) { - const uint32 source = sources[target]; - score += scores[target + source * num_nodes()]; - } - return score; - } - - // Returns the score of the maximum spanning tree (or forest, if the first - // test parameter is true) of the dense digraph defined by the |scores|, and - // sets |argmax_trees| to contain all maximal trees. - int32 RunBruteForceMstSolver(const ScoreMatrix &scores, - std::set *argmax_trees) { - CHECK_EQ(num_nodes() * num_nodes(), scores.size()); - int32 max_score; - argmax_trees->clear(); - - iterator_.ForEachTree(num_nodes(), [&](const SourceList &sources) { - const int32 score = ScoreArcs(scores, sources); - if (argmax_trees->empty() || max_score < score) { - max_score = score; - argmax_trees->clear(); - argmax_trees->insert(sources); - } else if (max_score == score) { - argmax_trees->insert(sources); - } - }); - - return max_score; - } - - // As above, but uses the |solver_| and extracts only one |argmax_tree|. - int32 RunMstSolver(const ScoreMatrix &scores, SourceList *argmax_tree) { - CHECK_EQ(num_nodes() * num_nodes(), scores.size()); - TF_CHECK_OK(solver_.Init(forest(), num_nodes())); - - // Add all roots and arcs. - for (uint32 source = 0; source < num_nodes(); ++source) { - for (uint32 target = 0; target < num_nodes(); ++target) { - const int32 score = scores[target + source * num_nodes()]; - if (source == target) { - solver_.AddRoot(target, score); - } else { - solver_.AddArc(source, target, score); - } - } - } - - // Solve for the max spanning tree. - argmax_tree->resize(num_nodes()); - TF_CHECK_OK(solver_.Solve(argmax_tree)); - return ScoreArcs(scores, *argmax_tree); - } - - // Returns a random ScoreMatrix spanning num_nodes() nodes. 
- ScoreMatrix RandomScores() { - ScoreMatrix scores(num_nodes() * num_nodes()); - for (int32 &value : scores) value = static_cast(prng_() % 201) - 100; - return scores; - } - - // Runs a comparison between MstSolver and BruteForceMst on random digraphs of - // num_nodes() nodes, for the specified number of trials. - void RunComparison() { - // Seed the PRNG, possibly non-deterministically. Log the seed value so the - // test results can be reproduced, even when the seed is non-deterministic. - uint32 seed = GetSeed(); - if (seed == 0) seed = time(nullptr); - prng_.seed(seed); - LOG(INFO) << "seed = " << seed; - - const int num_trials = GetNumTrials(); - for (int trial = 0; trial < num_trials; ++trial) { - const ScoreMatrix scores = RandomScores(); - - std::set expected_argmax_trees; - const int32 expected_max_score = - RunBruteForceMstSolver(scores, &expected_argmax_trees); - - SourceList actual_argmax_tree; - const int32 actual_max_score = RunMstSolver(scores, &actual_argmax_tree); - - // In case of ties, MstSolver will find a maximal spanning tree, but we - // don't know which one. - EXPECT_EQ(expected_max_score, actual_max_score); - ASSERT_THAT(expected_argmax_trees, Contains(actual_argmax_tree)); - } - } - - // Tree iterator for brute-force solver. - SpanningTreeIterator iterator_{forest()}; - - // MstSolver<> instance used by the test. Reused across all MST invocations - // to exercise reuse. - Solver solver_; - - // Pseudo-random number generator. 
- std::mt19937 prng_; -}; - -INSTANTIATE_TEST_SUITE_P(AllowForest, MstSolverRandomComparisonTest, - ::testing::Combine(::testing::Bool(), - ::testing::Range(1, 9))); - -TEST_P(MstSolverRandomComparisonTest, Comparison) { RunComparison(); } - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/mst_solver_test.cc b/tensorflow_text/core/kernels/mst_solver_test.cc deleted file mode 100644 index 782f7817e..000000000 --- a/tensorflow_text/core/kernels/mst_solver_test.cc +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/mst_solver.h" - -#include -#include -#include - -#include -#include -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace tensorflow { -namespace text { - -// Testing rig. -// -// Template args: -// Solver: An instantiation of the MstSolver<> template. -template -class MstSolverTest : public ::testing::Test { - protected: - using Index = typename Solver::IndexType; - using Score = typename Solver::ScoreType; - - // Adds directed arcs for all |num_nodes| nodes to the |solver_| with the - // |score|. 
- void AddAllArcs(Index num_nodes, Score score) { - for (Index source = 0; source < num_nodes; ++source) { - for (Index target = 0; target < num_nodes; ++target) { - if (source == target) continue; - solver_.AddArc(source, target, score); - } - } - } - - // Adds root selections for all |num_nodes| nodes to the |solver_| with the - // |score|. - void AddAllRoots(Index num_nodes, Score score) { - for (Index root = 0; root < num_nodes; ++root) { - solver_.AddRoot(root, score); - } - } - - // Runs the |solver_| using an argmax array of size |argmax_array_size| and - // expects it to fail with an error message that matches |error_substr|. - void SolveAndExpectError(int argmax_array_size, - const std::string &error_message_substr) { - std::vector argmax(argmax_array_size); - EXPECT_TRUE(absl::StrContains(solver_.Solve(&argmax).ToString(), - error_message_substr)); - } - - // As above, but expects success. Does not assert anything about the solution - // produced by the solver. - void SolveAndExpectOk(int argmax_array_size) { - std::vector argmax(argmax_array_size); - TF_EXPECT_OK(solver_.Solve(&argmax)); - } - - // As above, but expects the solution to be |expected_argmax| and infers the - // argmax array size. - void SolveAndExpectArgmax(const std::vector &expected_argmax) { - std::vector actual_argmax(expected_argmax.size()); - TF_ASSERT_OK(solver_.Solve(&actual_argmax)); - EXPECT_EQ(expected_argmax, actual_argmax); - } - - // MstSolver<> instance used by the test. Reused across all MST problems in - // each test to exercise reuse. 
- Solver solver_; -}; - -using Solvers = - ::testing::Types, MstSolver, - MstSolver, MstSolver, - MstSolver>; -TYPED_TEST_SUITE(MstSolverTest, Solvers); - -TYPED_TEST(MstSolverTest, FailIfNoNodes) { - for (const bool forest : {false, true}) { - EXPECT_TRUE(absl::StrContains(this->solver_.Init(forest, 0).ToString(), - "Non-positive number of nodes")); - } -} - -TYPED_TEST(MstSolverTest, FailIfTooManyNodes) { - // Set to a value that would overflow when doubled. - const auto kNumNodes = - (std::numeric_limits::max() / 2) + 10; - for (const bool forest : {false, true}) { - EXPECT_TRUE(absl::StrContains( - this->solver_.Init(forest, kNumNodes).ToString(), "Too many nodes")); - } -} - -TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsNoArcs) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->SolveAndExpectError(kNumNodes, "Infeasible digraph"); - } -} - -TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsAllArcs) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllArcs(kNumNodes, 0); - this->SolveAndExpectError(kNumNodes, "Infeasible digraph"); - } -} - -TYPED_TEST(MstSolverTest, FeasibleForForestOnlyIfAllRootsNoArcs) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - if (forest) { - this->SolveAndExpectOk(kNumNodes); // all roots is a valid forest - } else { - this->SolveAndExpectError(kNumNodes, "Infeasible digraph"); - } - } -} - -TYPED_TEST(MstSolverTest, FeasibleIfAllRootsAllArcs) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - this->SolveAndExpectOk(kNumNodes); - } -} - -TYPED_TEST(MstSolverTest, FailIfArgmaxArrayTooSmall) { - const int 
kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - this->SolveAndExpectError(kNumNodes - 1, // too small - "Argmax array too small"); - } -} - -TYPED_TEST(MstSolverTest, OkIfArgmaxArrayTooLarge) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - this->SolveAndExpectOk(kNumNodes + 1); // too large - } -} - -TYPED_TEST(MstSolverTest, SolveForAllRootsForestOnly) { - const int kNumNodes = 10; - const bool forest = true; - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 1); // favor all root selections - this->AddAllArcs(kNumNodes, 0); - this->SolveAndExpectArgmax({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); -} - -TYPED_TEST(MstSolverTest, SolveForLeftToRightChain) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - for (int target = 1; target < kNumNodes; ++target) { - this->solver_.AddArc(target - 1, target, 1); // favor left-to-right chain - } - this->SolveAndExpectArgmax({0, 0, 1, 2, 3, 4, 5, 6, 7, 8}); - } -} - -TYPED_TEST(MstSolverTest, SolveForRightToLeftChain) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - for (int source = 1; source < kNumNodes; ++source) { - this->solver_.AddArc(source, source - 1, 1); // favor right-to-left chain - } - this->SolveAndExpectArgmax({1, 2, 3, 4, 5, 6, 7, 8, 9, 9}); - } -} - -TYPED_TEST(MstSolverTest, SolveForAllFromFirstTree) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - 
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - for (int target = 1; target < kNumNodes; ++target) { - this->solver_.AddArc(0, target, 1); // favor first -> target - } - this->SolveAndExpectArgmax({0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - } -} - -TYPED_TEST(MstSolverTest, SolveForAllFromLastTree) { - const int kNumNodes = 10; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - for (int target = 0; target + 1 < kNumNodes; ++target) { - this->solver_.AddArc(9, target, 1); // favor last -> target - } - this->SolveAndExpectArgmax({9, 9, 9, 9, 9, 9, 9, 9, 9, 9}); - } -} - -TYPED_TEST(MstSolverTest, SolveForBinaryTree) { - const int kNumNodes = 15; - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes)); - this->AddAllRoots(kNumNodes, 0); - this->AddAllArcs(kNumNodes, 0); - for (int target = 1; target < kNumNodes; ++target) { - this->solver_.AddArc((target - 1) / 2, target, 1); // like a binary heap - } - // clang-format off - this->SolveAndExpectArgmax({0, - 0, 0, - 1, 1, 2, 2, - 3, 3, 4, 4, 5, 5, 6, 6}); - // clang-format on - } -} - -TYPED_TEST(MstSolverTest, ScoreAccessors) { - for (const bool forest : {false, true}) { - TF_ASSERT_OK(this->solver_.Init(forest, 10)); - this->solver_.AddArc(0, 1, 0); - this->solver_.AddArc(1, 4, 1); - this->solver_.AddArc(7, 6, 2); - this->solver_.AddArc(9, 2, 3); - - this->solver_.AddRoot(0, 10); - this->solver_.AddRoot(2, 20); - this->solver_.AddRoot(8, 30); - - EXPECT_EQ(this->solver_.ArcScore(0, 1), 0); - EXPECT_EQ(this->solver_.ArcScore(1, 4), 1); - EXPECT_EQ(this->solver_.ArcScore(7, 6), 2); - EXPECT_EQ(this->solver_.ArcScore(9, 2), 3); - - EXPECT_EQ(this->solver_.RootScore(0), 10); - EXPECT_EQ(this->solver_.RootScore(2), 20); - EXPECT_EQ(this->solver_.RootScore(8), 30); - } -} - -} // namespace text -} 
// namespace tensorflow diff --git a/tensorflow_text/core/kernels/ngrams_kernel.cc b/tensorflow_text/core/kernels/ngrams_kernel.cc deleted file mode 100644 index de9486e6f..000000000 --- a/tensorflow_text/core/kernels/ngrams_kernel.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/ngrams_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER( - Name(NgramsStringJoinKernel::OpName()).Device(tensorflow::DEVICE_CPU), - NgramsStringJoinKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/ngrams_kernel.h b/tensorflow_text/core/kernels/ngrams_kernel.h index e8c13d603..963b7f956 100644 --- a/tensorflow_text/core/kernels/ngrams_kernel.h +++ b/tensorflow_text/core/kernels/ngrams_kernel.h @@ -12,37 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/ngrams_kernel_template.h" - -namespace tensorflow { -namespace text { - -class NgramsStringJoinKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/ngrams_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/ngrams_kernel_template.h b/tensorflow_text/core/kernels/ngrams_kernel_template.h index 1a8a3fc8f..0190a67a6 100644 --- a/tensorflow_text/core/kernels/ngrams_kernel_template.h +++ b/tensorflow_text/core/kernels/ngrams_kernel_template.h @@ -12,265 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/ngrams_kernel_template.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow/lite/kernels/shim/tensor_view.h" - -namespace tensorflow { -namespace text { - -// text.ngrams op kernel. See `kDoc` for more info. -template -class NgramsStringJoin : public tflite::shim::OpKernelShim { - protected: - using Shape = tflite::shim::Shape; - - public: - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - NgramsStringJoin() = default; - static constexpr char kOpName[] = "TFText>NgramsStringJoin"; - static constexpr char kDoc[] = R"doc( - Create a tensor of n-grams based on the string input data. - - Args: - input_values: A string tensor, or a ragged string tensor (a 1D string value - tensor and one or more 1D int64 row_split tensors). - row_splits: List of integer tensors representing the splits of the - input_values - width: scalar integer - The width of the ngram window. - axis: scalar integer - The axis to create ngrams along. Currently, it must be -1. - string_separator: scalar string - The separator string used to join tokens together. - - Returns: - output_values: A string tensor that matches the rank of 'data'. Will be a - ragged tensor if 'data' is a ragged tensor. - output_row_splits: Splits of above. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration - static std::vector Attrs() { - return {"width: int", - "axis: int", - "string_separator: string", - "RAGGED_RANK: int >= 0", - "Tsplits: {int64} = DT_INT64"}; - } - // Input tensors declaration - static std::vector Inputs() { - return {"input_values: string", "input_row_splits: RAGGED_RANK * Tsplits"}; - } - // Output tensors declaration - static std::vector Outputs() { - return {"output_values: string", - "output_row_splits: RAGGED_RANK * Tsplits"}; - } - - // Initializes the op - absl::Status Init(InitContext* ctx) { - int64_t axis; - SH_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis)); - if (axis != -1) { - return absl::InternalError(absl::StrCat("axis != -1: ", axis)); - } - SH_RETURN_IF_ERROR(ctx->GetAttr("width", &width_)); - absl::string_view string_separator; - SH_RETURN_IF_ERROR(ctx->GetAttr("string_separator", &string_separator)); - string_separator_ = std::string(string_separator); - return absl::OkStatus(); - } - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* ctx) { - if (ctx->NumOutputs() == 1) { - // Tensor Output - SH_ASSIGN_OR_RETURN(const auto input_shape, ctx->GetInputShape(kValues)); - int64_t width; - SH_RETURN_IF_ERROR(ctx->GetAttr("width", &width)); - SH_RETURN_IF_ERROR(ctx->SetOutputShape( - kValues, OutputValuesTensorShape(input_shape, width))); - } else { - // RaggedTensor Output - SH_ASSIGN_OR_RETURN(const auto input_shape, ctx->GetInputShape(kValues)); - Shape output_shape(input_shape); - const int last_dim = output_shape->size() - 1; - if (last_dim != -1) { - (*output_shape)[last_dim] = output_shape.kUnknownDim; - } - SH_RETURN_IF_ERROR(ctx->SetOutputShape(kValues, output_shape)); - - // The row_splits tensors maintain their shape, because only the - // innermost dimension will change. 
- for (int i = kRowSplitsStart; i < ctx->NumOutputs(); ++i) { - SH_ASSIGN_OR_RETURN(const Shape input_row_splits_shape, - ctx->GetInputShape(i)); - if (input_row_splits_shape.Rank() != 1) { - return absl::InvalidArgumentError( - absl::StrCat("expected rank == 1 for input index: ", i)); - } - SH_RETURN_IF_ERROR(ctx->SetOutputShape(i, input_row_splits_shape)); - } - } - return absl::OkStatus(); - } - - // Runs the operation - absl::Status Invoke(InvokeContext* ctx) { - using Tsplits = int64_t; - // Storage for the dummy input and output row_splits used in the tensor - // case. - std::vector tensor_input_row_splits; - std::vector tensor_output_row_splits; - - const Tsplits* input_row_splits; - Tsplits* output_row_splits; - int n_row_splits = 0; - - SH_ASSIGN_OR_RETURN(const auto input_values, ctx->GetInput(kValues)); - const Shape input_values_shape(input_values->Shape()); - - // Tensor output - if (ctx->NumOutputs() == 1) { - // Generate mock input and output innermost row_splits. - int64_t total_tokens = - input_values->template Data().size(); - int64_t tokens_per_element = - input_values_shape->at(input_values_shape->size() - 1); - tensor_output_row_splits.resize(total_tokens / tokens_per_element + 1); - for (int64_t i = 0; i <= total_tokens; i += tokens_per_element) { - tensor_input_row_splits.push_back(i); - } - input_row_splits = tensor_input_row_splits.data(); - output_row_splits = tensor_output_row_splits.data(); - n_row_splits = tensor_input_row_splits.size(); - } else { - // RaggedTensor output - int index = 0; - const int num_row_splits = ctx->NumInputs() - kRowSplitsStart; - // Copy all input splits except for innermost into output splits. 
- while (index < num_row_splits - 1) { - SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits, - ctx->GetInput(kRowSplitsStart + index)); - SH_ASSIGN_OR_RETURN( - const auto output_tensor_row_splits, - ctx->GetOutput(kRowSplitsStart + index, - Shape(input_tensor_row_splits->Shape()))); - const auto input_buffer = - input_tensor_row_splits->template Data(); - const auto output_buffer = - output_tensor_row_splits->template Data(); - std::memcpy(output_buffer.data(), input_buffer.data(), - input_buffer.size() * sizeof(Tsplits)); - ++index; - } - // Set row splits variables to the innermost - SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits, - ctx->GetInput(kRowSplitsStart + index)); - SH_ASSIGN_OR_RETURN( - const auto output_tensor_row_splits, - ctx->GetOutput(kRowSplitsStart + index, - Shape(input_tensor_row_splits->Shape()))); - input_row_splits = - input_tensor_row_splits->template Data().data(); - output_row_splits = - output_tensor_row_splits->template Data().data(); - n_row_splits = input_tensor_row_splits->Shape().at(0); - } - - const auto input_values_data = - input_values->template Data(); - - // Create ngrams by looping through the innermost input splits. - std::vector buffer; - for (int i = 0; i < n_row_splits - 1; ++i) { - // Set output splits using current number of created output values. - output_row_splits[i] = buffer.size(); - std::vector tokens; - for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) { - tokens.emplace_back(input_values_data.at(j)); - if (tokens.size() < width_) continue; - tokens.erase(tokens.begin(), tokens.begin() + tokens.size() - width_); - buffer.push_back(absl::StrJoin(tokens, string_separator_)); - } - } - output_row_splits[n_row_splits - 1] = buffer.size(); - - // Set output values from the generated buffer. 
- tflite::shim::TensorViewOr output_values_or; - if (ctx->NumOutputs() == 1) { - output_values_or = ctx->GetOutput( - kValues, OutputValuesTensorShape(input_values_shape, width_)); - } else { - output_values_or = - ctx->GetOutput(kValues, Shape({static_cast(buffer.size())})); - } - if (!output_values_or.ok()) return output_values_or.status(); - auto& output_buffer = - output_values_or.value()->template Data(); - int i = 0; - for (const auto& v : buffer) output_buffer[i++] = v; - return absl::OkStatus(); - } - - protected: - inline static Shape OutputValuesTensorShape(const Shape& input_values_shape, - const int64_t width) { - // If the input shape is unknown, so is the output shape. - if (input_values_shape.Rank() == input_values_shape.kUnknownRank) - return input_values_shape; - - Shape output_shape(input_values_shape); - const int last_dim = output_shape->size() - 1; - if (input_values_shape->at(last_dim) == input_values_shape.kUnknownDim) - return output_shape; - (*output_shape)[last_dim] = - std::max(0, output_shape->at(last_dim) - static_cast(width) + 1); - return output_shape; - } - - // Both the input and output tensors use the same indices. - static constexpr int kValues = 0; - static constexpr int kRowSplitsStart = 1; - - int64_t width_; - std::string string_separator_; -}; - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/ngrams_kernel_test.cc b/tensorflow_text/core/kernels/ngrams_kernel_test.cc deleted file mode 100644 index a6b70925c..000000000 --- a/tensorflow_text/core/kernels/ngrams_kernel_test.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); - -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference_testutil.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { - -TEST(NgramsStringJoin, UnknownRank) { - ShapeInferenceTestOp op("TFText>NgramsStringJoin"); - op.input_tensors.resize(1); - AddNodeAttr("RAGGED_RANK", 0, &op.node_def); - AddNodeAttr("width", 1, &op.node_def); - - INFER_OK(op, "?", "?"); -} - -TEST(NgramsStringJoin, KnownRankUnknownDims) { - ShapeInferenceTestOp op("TFText>NgramsStringJoin"); - op.input_tensors.resize(1); - AddNodeAttr("RAGGED_RANK", 0, &op.node_def); - AddNodeAttr("width", 1, &op.node_def); - - INFER_OK(op, "[1,?]", "[1,?]"); -} - -TEST(NgramsStringJoin, LastDimWidth) { - ShapeInferenceTestOp op("TFText>NgramsStringJoin"); - op.input_tensors.resize(1); - AddNodeAttr("RAGGED_RANK", 0, 
&op.node_def); - AddNodeAttr("width", 3, &op.node_def); - - INFER_OK(op, "[?,5]", "[?,3]"); -} - -TEST(NgramsStringJoin, LastDimWidthClampZero) { - ShapeInferenceTestOp op("TFText>NgramsStringJoin"); - op.input_tensors.resize(1); - AddNodeAttr("RAGGED_RANK", 0, &op.node_def); - AddNodeAttr("width", 3, &op.node_def); - - INFER_OK(op, "[?,1]", "[?,0]"); -} - -} // end namespace tensorflow diff --git a/tensorflow_text/core/kernels/ngrams_tflite.cc b/tensorflow_text/core/kernels/ngrams_tflite.cc deleted file mode 100644 index 71c34bb2b..000000000 --- a/tensorflow_text/core/kernels/ngrams_tflite.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/ngrams_tflite.h" - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/ngrams_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -using OpKernel = - tflite::shim::TfLiteOpKernel; - -extern "C" void AddNgramsStringJoin(tflite::MutableOpResolver* resolver) { - OpKernel::Add(resolver); -} - -TfLiteRegistration* Register_TFText_NgramsStringJoin() { - return OpKernel::GetTfLiteRegistration(); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/ngrams_tflite.h b/tensorflow_text/core/kernels/ngrams_tflite.h index 0f0700ad0..02bfa93a8 100644 --- a/tensorflow_text/core/kernels/ngrams_tflite.h +++ b/tensorflow_text/core/kernels/ngrams_tflite.h @@ -15,39 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_TFLITE_H_ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -// Adds the Ngrams custom op to an op resolver. -// This function can be loaded using dlopen. Since C++ function names get -// mangled, declare this function as extern C, so its name is unchanged. -extern "C" void AddNgramsStringJoin(MutableOpResolver* resolver); - -TfLiteRegistration* Register_TFText_NgramsStringJoin(); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/ngrams_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/ngrams_tflite_test.cc b/tensorflow_text/core/kernels/ngrams_tflite_test.cc deleted file mode 100644 index 0e2e88e61..000000000 --- a/tensorflow_text/core/kernels/ngrams_tflite_test.cc +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_text/core/kernels/ngrams_tflite.h" - -#include -#include - -#include -#include -#include "flatbuffers/flexbuffers.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_util.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -namespace { - -using ::testing::ElementsAre; -using ::testing::ElementsAreArray; - -class NgramsModel : public SingleOpModel { - public: - // Constructor for testing the op with a tf.Tensor - NgramsModel(int width, const std::string& string_separator, - const std::vector& input_values, - const std::vector& input_shape) { - input_values_ = AddInput(TensorType_STRING); - output_values_ = AddOutput(TensorType_STRING); - - BuildCustomOp(width, string_separator); - - BuildInterpreter({input_shape}); - PopulateStringTensor(input_values_, input_values); - Invoke(); - } - - // Constructor for the op with a tf.RaggedTensor - // Note: This interface uses row_lengths, as they're closer to the - // dimensions in a TensorShape, but internally everything is row_splits. 
- NgramsModel(int width, const std::string& string_separator, - const std::vector& input_values, - const std::vector> nested_row_lengths) { - std::vector> input_shapes; - input_shapes.reserve(nested_row_lengths.size() + 1); - - input_values_ = AddInput(TensorType_STRING); - input_shapes.push_back({static_cast(input_values.size())}); - output_values_ = AddOutput(TensorType_STRING); - - input_row_splits_.reserve(nested_row_lengths.size()); - output_row_splits_.reserve(nested_row_lengths.size()); - for (int i = 0; i < nested_row_lengths.size(); ++i) { - input_row_splits_.push_back(AddInput(TensorType_INT64)); - input_shapes.push_back( - {static_cast(nested_row_lengths[i].size() + 1)}); - output_row_splits_.push_back(AddOutput(TensorType_INT64)); - } - - BuildCustomOp(width, string_separator); - - BuildInterpreter(input_shapes); - PopulateStringTensor(input_values_, input_values); - for (int i = 0; i < nested_row_lengths.size(); ++i) { - std::vector row_splits; - row_splits.reserve(nested_row_lengths[i].size() + 1); - int64_t index = 0; - row_splits.push_back(index); - for (int64_t row_length : nested_row_lengths[i]) { - index += row_length; - row_splits.push_back(index); - } - PopulateTensor(input_row_splits_[i], row_splits); - } - Invoke(); - } - - std::vector GetValuesTensorShape() { - return GetTensorShape(output_values_); - } - - std::vector ExtractValuesTensorVector() { - std::vector r; - TfLiteTensor* tensor = interpreter_->tensor(output_values_); - int n = GetStringCount(tensor); - for (int i = 0; i < n; ++i) { - StringRef ref = GetString(tensor, i); - r.emplace_back(ref.str, ref.len); - } - return r; - } - - int GetNumNestedRowLengths() { return output_row_splits_.size(); } - - std::vector GetRowLengthsTensorShape(int i) { - std::vector shape = GetTensorShape(output_row_splits_[i]); - --shape[0]; - return shape; - } - - std::vector ExtractRowLengthsTensorVector(int i) { - std::vector row_splits = - ExtractVector(output_row_splits_[i]); - std::vector 
row_lengths; - row_lengths.reserve(row_splits.size() - 1); - int64_t head = row_splits[0]; - for (int i = 1; i < row_splits.size(); ++i) { - int64_t tail = row_splits[i]; - row_lengths.push_back(tail - head); - head = tail; - } - return row_lengths; - } - - private: - void BuildCustomOp(int width, const std::string& string_separator) { - flexbuffers::Builder fbb; - size_t start_map = fbb.StartMap(); - fbb.Int("width", width); - fbb.String("string_separator", string_separator); - fbb.Int("axis", -1); - fbb.String("reduction_type", "STRING_JOIN"); - fbb.EndMap(start_map); - fbb.Finish(); - - SetCustomOp("TFText>NgramsStringJoin", fbb.GetBuffer(), - Register_TFText_NgramsStringJoin); - } - - int input_values_; - std::vector input_row_splits_; - int output_values_; - std::vector output_row_splits_; -}; - -TEST(NgramsTest, TensorSingleSequenceWidthTwo) { - NgramsModel m(2, " ", {"this", "is", "a", "test"}, std::vector{4}); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this is", "is a", "a test")); -} - -TEST(NgramsTest, TensorSingleSequenceWidthThree) { - NgramsModel m(3, " ", {"this", "is", "a", "test"}, std::vector{4}); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this is a", "is a test")); -} - -TEST(NgramsTest, TensorSingleSequenceLongerSeparator) { - NgramsModel m(2, "...", {"this", "is", "a", "test"}, std::vector{4}); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this...is", "is...a", "a...test")); -} - -TEST(NgramsTest, TensorSingleSequenceWidthTooLong) { - NgramsModel m(5, " ", {"this", "is", "a", "test"}, std::vector{4}); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0)); - EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre()); -} - -TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) { - NgramsModel m(2, " ", - { - "0,0,0", "0,0,1", 
"0,0,2", "0,0,3", // - "0,1,0", "0,1,1", "0,1,2", "0,1,3", // - "0,2,0", "0,2,1", "0,2,2", "0,2,3", // - "1,0,0", "1,0,1", "1,0,2", "1,0,3", // - "1,1,0", "1,1,1", "1,1,2", "1,1,3", // - "1,2,0", "1,2,1", "1,2,2", "1,2,3", // - }, - std::vector{2, 3, 4}); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2, 3, 3)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAreArray({ - "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", // - "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", // - "0,2,0 0,2,1", "0,2,1 0,2,2", "0,2,2 0,2,3", // - "1,0,0 1,0,1", "1,0,1 1,0,2", "1,0,2 1,0,3", // - "1,1,0 1,1,1", "1,1,1 1,1,2", "1,1,2 1,1,3", // - "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", // - })); -} - -TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) { - std::vector> nested_row_lengths; - nested_row_lengths.push_back({4}); - NgramsModel m(2, " ", {"this", "is", "a", "test"}, - nested_row_lengths); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this is", "is a", "a test")); - ASSERT_THAT(m.GetNumNestedRowLengths(), 1); - EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3)); -} - -TEST(NgramsTest, RaggedTensorSingleSequenceWidthThree) { - std::vector> nested_row_lengths; - nested_row_lengths.push_back({4}); - NgramsModel m(3, " ", {"this", "is", "a", "test"}, nested_row_lengths); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this is a", "is a test")); - ASSERT_THAT(m.GetNumNestedRowLengths(), 1); - EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(2)); -} - -TEST(NgramsTest, RaggedTensorSingleSequenceLongerSeparator) { - std::vector> nested_row_lengths; - nested_row_lengths.push_back({4}); - NgramsModel m(2, "<>", {"this", "is", "a", "test"}, nested_row_lengths); - 
EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); - EXPECT_THAT(m.ExtractValuesTensorVector(), - ElementsAre("this<>is", "is<>a", "a<>test")); - ASSERT_THAT(m.GetNumNestedRowLengths(), 1); - EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3)); -} - -TEST(NgramsTest, RaggedTensorSingleSequenceWidthTooLong) { - std::vector> nested_row_lengths; - nested_row_lengths.push_back({4}); - NgramsModel m(5, " ", {"this", "is", "a", "test"}, nested_row_lengths); - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0)); - EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre()); - ASSERT_THAT(m.GetNumNestedRowLengths(), 1); - EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(0)); -} - -TEST(NgramsTest, RaggedTensorMultidimensionalInputWidthTwo) { - std::vector> nested_row_lengths; - nested_row_lengths.push_back({4, 2, 1}); - nested_row_lengths.push_back({5, 4, 3, 2, 2, 3, 4, 6}); - NgramsModel m(2, " ", - { - "0,0,0", "0,0,1", "0,0,2", "0,0,3", "0,0,4", // - "0,1,0", "0,1,1", "0,1,2", "0,1,3", // - "0,2,0", "0,2,1", "0,2,2", // - "0,3,0", "0,3,1", // - "1,0,0", "1,0,1", // - "1,1,0", "1,1,1", "1,1,2", // - "1,2,0", "1,2,1", "1,2,2", "1,2,3", // - "2,0,0", "2,0,1", "2,0,2", "2,0,3", "2,0,4", "2,0,5", // - }, - nested_row_lengths); - - std::vector expected_values = { - "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", "0,0,3 0,0,4", // - "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", // - "0,2,0 0,2,1", "0,2,1 0,2,2", // - "0,3,0 0,3,1", // - "1,0,0 1,0,1", // - "1,1,0 1,1,1", "1,1,1 1,1,2", // - "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", // - "2,0,0 2,0,1", "2,0,1 2,0,2", "2,0,2 2,0,3", "2,0,3 2,0,4", - "2,0,4 2,0,5", // - }; - EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(expected_values.size())); - EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAreArray(expected_values)); - ASSERT_THAT(m.GetNumNestedRowLengths(), 2); - 
EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(3)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(4, 2, 1)); - EXPECT_THAT(m.GetRowLengthsTensorShape(1), ElementsAre(8)); - EXPECT_THAT(m.ExtractRowLengthsTensorVector(1), - ElementsAre(4, 3, 2, 1, 1, 2, 3, 5)); -} - -} // namespace -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/normalize_kernels.cc b/tensorflow_text/core/kernels/normalize_kernels.cc deleted file mode 100644 index e011a4629..000000000 --- a/tensorflow_text/core/kernels/normalize_kernels.cc +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include - -#include "absl/strings/ascii.h" -#include "absl/strings/str_cat.h" -#include "icu4c/source/common/unicode/edits.h" -#include "icu4c/source/common/unicode/errorcode.h" -#include "icu4c/source/common/unicode/normalizer2.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_encode_decode.h" -#include "tensorflow_text/core/kernels/edit_changes.pb.h" - -namespace tensorflow { -namespace text { - -class CaseFoldUTF8Op : public tensorflow::OpKernel { - public: - explicit CaseFoldUTF8Op(tensorflow::OpKernelConstruction* context) - : tensorflow::OpKernel(context) {} - - void Compute(tensorflow::OpKernelContext* context) override { - const tensorflow::Tensor* input_tensor; - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - const auto& input_vec = input_tensor->flat(); - - // TODO(gregbillock): support forwarding - tensorflow::Tensor* output_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), - &output_tensor)); - auto output_vec = output_tensor->flat(); - - icu::ErrorCode icu_error; - const icu::Normalizer2* nfkc_cf = - icu::Normalizer2::getNFKCCasefoldInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal(absl::StrCat( - icu_error.errorName(), - ": Could not retrieve ICU NFKC_CaseFold normalizer"))); - - for (int64 i = 0; i < input_vec.size(); ++i) { - string output_text; - icu::StringByteSink byte_sink(&output_text); - const auto& input = input_vec(i); - nfkc_cf->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()), - byte_sink, nullptr, icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal("Could not normalize input 
string: " + - input_vec(i))); - output_vec(i) = output_text; - } - } -}; - -REGISTER_KERNEL_BUILDER(Name("CaseFoldUTF8").Device(tensorflow::DEVICE_CPU), - CaseFoldUTF8Op); - -namespace { - -string GetNormalizationForm(OpKernelConstruction* context) { - string normalization_form; - ([=](string* c) -> void { - OP_REQUIRES_OK(context, context->GetAttr("normalization_form", c)); - })(&normalization_form); - return absl::AsciiStrToUpper(normalization_form); -} - -} // namespace - -class NormalizeUTF8Op : public tensorflow::OpKernel { - public: - explicit NormalizeUTF8Op(tensorflow::OpKernelConstruction* context) - : tensorflow::OpKernel(context), - normalization_form_(GetNormalizationForm(context)) {} - - void Compute(tensorflow::OpKernelContext* context) override { - const tensorflow::Tensor* input_tensor; - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - const auto& input_vec = input_tensor->flat(); - - tensorflow::Tensor* output_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), - &output_tensor)); - auto output_vec = output_tensor->flat(); - - icu::ErrorCode icu_error; - const icu::Normalizer2* normalizer = nullptr; - if (normalization_form_ == "NFKC") { - normalizer = icu::Normalizer2::getNFKCInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal(absl::StrCat( - icu_error.errorName(), - ": Could not retrieve ICU NFKC normalizer"))); - } else if (normalization_form_ == "NFC") { - normalizer = icu::Normalizer2::getNFCInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal( - absl::StrCat(icu_error.errorName(), - ": Could not retrieve ICU NFC normalizer"))); - } else if (normalization_form_ == "NFD") { - normalizer = icu::Normalizer2::getNFDInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal( - absl::StrCat(icu_error.errorName(), - ": Could not retrieve ICU NFD normalizer"))); - } else if (normalization_form_ == "NFKD") { 
- normalizer = icu::Normalizer2::getNFKDInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal(absl::StrCat( - icu_error.errorName(), - ": Could not retrieve ICU NFKd normalizer"))); - } else { - OP_REQUIRES( - context, false, - errors::InvalidArgument(absl::StrCat( - "Unknown normalization form requrested: ", normalization_form_))); - } - - for (int64 i = 0; i < input_vec.size(); ++i) { - string output_text; - icu::StringByteSink byte_sink(&output_text); - const auto& input = input_vec(i); - normalizer->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()), - byte_sink, nullptr, icu_error); - OP_REQUIRES( - context, icu_error.isSuccess(), - errors::Internal(absl::StrCat(icu_error.errorName(), - ": Could not normalize input string: ", - absl::string_view(input_vec(i))))); - output_vec(i) = output_text; - } - } - - private: - string normalization_form_; -}; - -REGISTER_KERNEL_BUILDER(Name("NormalizeUTF8").Device(tensorflow::DEVICE_CPU), - NormalizeUTF8Op); - -namespace { - -// OffsetMapVariant is a tf.Variant object that stores a single icu::Edits -// object and providing encode/decode methods. -// The encode method is called to serialize the stored icu::Edits object when -// the variant is assigned to graph output. The decode method is called to -// reconstruct the icu::Edits object from the serialized `changes` string when -// the variant is at the graph input. 
-struct OffsetMapVariant { - string changes; - icu::Edits edits_; - - std::string TypeName() const { return "(anonymous)::OffsetMapVariant"; } - void Encode(tensorflow::VariantTensorData* data) const; - bool Decode(const tensorflow::VariantTensorData& data); -}; - -void OffsetMapVariant::Encode(tensorflow::VariantTensorData* data) const { - EditChanges changes; - icu::Edits::Iterator it = edits_.getFineIterator(); - icu::ErrorCode icu_error; - while (it.next(icu_error)) { - auto* change = changes.add_change(); - change->set_old_length(it.oldLength()); - change->set_new_length(it.newLength()); - } - string changes_str = changes.SerializeAsString(); - data->set_metadata(changes_str); -} - -bool OffsetMapVariant::Decode(const tensorflow::VariantTensorData& data) { - string serialized; - data.get_metadata(&serialized); - EditChanges changes; - changes.ParseFromString(serialized); - icu::Edits edit; - icu::ErrorCode icu_error; - for (int64 j = 0; j < changes.change_size(); ++j) { - auto* change = changes.mutable_change(j); - int old_length = change->old_length(); - int new_length = change->new_length(); - if (old_length == new_length) { - edit.addUnchanged(static_cast(old_length)); - } else { - edit.addReplace(static_cast(old_length), - static_cast(new_length)); - } - } - edits_ = edit; - return true; -} -} // namespace - -class NormalizeUTF8WithOffsetsMapOp : public tensorflow::OpKernel { - public: - explicit NormalizeUTF8WithOffsetsMapOp( - tensorflow::OpKernelConstruction* context) - : tensorflow::OpKernel(context), - normalization_form_(GetNormalizationForm(context)) {} - - void Compute(tensorflow::OpKernelContext* context) override { - const tensorflow::Tensor* input_tensor; - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - const auto& input_vec = input_tensor->flat(); - - tensorflow::Tensor* output_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), - &output_tensor)); - tensorflow::Tensor* 
output_offsets_map_tensor; - OP_REQUIRES_OK(context, - context->allocate_output(1, input_tensor->shape(), - &output_offsets_map_tensor)); - - auto output_vec = output_tensor->flat(); - auto output_offsets_map_vec = output_offsets_map_tensor->flat(); - - icu::ErrorCode icu_error; - const icu::Normalizer2* normalizer = nullptr; - if (normalization_form_ == "NFKC") { - normalizer = icu::Normalizer2::getNFKCInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal(absl::StrCat( - icu_error.errorName(), - ": Could not retrieve ICU NFKC normalizer"))); - } else if (normalization_form_ == "NFC") { - normalizer = icu::Normalizer2::getNFCInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal( - absl::StrCat(icu_error.errorName(), - ": Could not retrieve ICU NFC normalizer"))); - } else if (normalization_form_ == "NFD") { - normalizer = icu::Normalizer2::getNFDInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal( - absl::StrCat(icu_error.errorName(), - ": Could not retrieve ICU NFD normalizer"))); - } else if (normalization_form_ == "NFKD") { - normalizer = icu::Normalizer2::getNFKDInstance(icu_error); - OP_REQUIRES(context, icu_error.isSuccess(), - errors::Internal(absl::StrCat( - icu_error.errorName(), - ": Could not retrieve ICU NFKD normalizer"))); - } else { - OP_REQUIRES(context, false, - errors::InvalidArgument(absl::StrCat( - "Offset not supported for this normalization form: ", - normalization_form_))); - } - - for (int64 i = 0; i < input_vec.size(); ++i) { - OffsetMapVariant variant; - string output_text; - icu::Edits edits; - icu::StringByteSink byte_sink(&output_text); - const auto& input = input_vec(i); - normalizer->normalizeUTF8(0, icu::StringPiece(input.data(), input.size()), - byte_sink, &edits, icu_error); - OP_REQUIRES( - context, icu_error.isSuccess(), - errors::Internal(absl::StrCat(icu_error.errorName(), - ": Could not normalize input string: ", - 
absl::string_view(input_vec(i))))); - - output_vec(i) = output_text; - variant.edits_ = std::move(edits); - output_offsets_map_vec(i) = variant; - } - } - - private: - string normalization_form_; -}; - -REGISTER_KERNEL_BUILDER( - Name("NormalizeUTF8WithOffsetsMap").Device(tensorflow::DEVICE_CPU), - NormalizeUTF8WithOffsetsMapOp); - -template -class FindSourceOffsetsOp : public tensorflow::OpKernel { - public: - explicit FindSourceOffsetsOp(tensorflow::OpKernelConstruction* context) - : tensorflow::OpKernel(context) {} - - void Compute(tensorflow::OpKernelContext* context) override { - const tensorflow::Tensor& edits_values = context->input(0); - const tensorflow::Tensor& input_offsets_values = context->input(1); - const tensorflow::Tensor& input_offsets_splits = context->input(2); - - const auto& input_offsets_values_vec = input_offsets_values.flat(); - const auto& input_offsets_splits_vec = - input_offsets_splits.flat(); - const auto& edits_vec = edits_values.flat(); - - icu::ErrorCode icu_error; - int64 cur_split_index_begin = 0; - int64 cur_split_index_end = 0; - std::vector output_offsets_values(input_offsets_values_vec.size()); - int64 idx_edits = 0; - int64 idx_output = 0; - for (int64 i = 0; i < input_offsets_splits_vec.size() - 1; ++i) { - cur_split_index_begin = input_offsets_splits_vec(i); - cur_split_index_end = input_offsets_splits_vec(i + 1); - if (cur_split_index_begin == cur_split_index_end) { - continue; - } - OP_REQUIRES(context, idx_edits < edits_vec.size(), - tensorflow::errors::InvalidArgument( - "Input offset tensor dimension did not match the offset " - "map dimension.")); - auto iter = edits_vec(idx_edits++) - .get() - ->edits_.getFineChangesIterator(); - for (int64 j = cur_split_index_begin; j < cur_split_index_end; ++j) { - output_offsets_values[idx_output++] = - iter.sourceIndexFromDestinationIndex(input_offsets_values_vec(j), - icu_error); - } - } - OP_REQUIRES(context, idx_edits == edits_vec.size(), - tensorflow::errors::InvalidArgument( 
- "Input offset tensor dimension did not match the offset " - "map dimension.")); - - int64 output_offsets_values_size = output_offsets_values.size(); - Tensor* output_offsets_values_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "output_offsets_values", - TensorShape({output_offsets_values_size}), - &output_offsets_values_tensor)); - auto output_offsets_values_data = - output_offsets_values_tensor->flat().data(); - memcpy(output_offsets_values_data, output_offsets_values.data(), - output_offsets_values_size * sizeof(int64)); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(FindSourceOffsetsOp); -}; - -REGISTER_KERNEL_BUILDER(Name("FindSourceOffsets") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint("Tsplits"), - FindSourceOffsetsOp); -REGISTER_KERNEL_BUILDER(Name("FindSourceOffsets") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint("Tsplits"), - FindSourceOffsetsOp); -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/normalize_kernels_test.cc b/tensorflow_text/core/kernels/normalize_kernels_test.cc deleted file mode 100644 index a3aa0207b..000000000 --- a/tensorflow_text/core/kernels/normalize_kernels_test.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. diff --git a/tensorflow_text/core/kernels/phrase_tokenizer.cc b/tensorflow_text/core/kernels/phrase_tokenizer.cc deleted file mode 100644 index cfffe87fe..000000000 --- a/tensorflow_text/core/kernels/phrase_tokenizer.cc +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/phrase_tokenizer.h" - -#include -#include -#include -#include -#include - -#include "absl/strings/match.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" - -namespace tensorflow { -namespace text { - -/*static*/ absl::StatusOr PhraseTokenizer::Create( - const void* config_flatbuffer) { - PhraseTokenizer tokenizer; - // `GetPhraseTokenizerConfig()` is autogenerated by flatbuffer. - tokenizer.phrase_config_ = GetPhraseTokenizerConfig(config_flatbuffer); - tokenizer.trie_ = absl::make_unique( - tokenizer.phrase_config_->vocab_trie()->nodes()); - tokenizer.prob_ = static_cast(tokenizer.phrase_config_->prob()) / 100; - const auto& ws_config = tokenizer.phrase_config_->whitespace_config(); - tokenizer.whitespace_config_str_ = - absl::string_view(ws_config->c_str(), ws_config->size()); - tokenizer.whitespace_tokenizer_ = absl::make_unique( - WhitespaceTokenizerConfig(tokenizer.whitespace_config_str_)); - tokenizer.split_end_punctuation_ = - tokenizer.phrase_config_->split_end_punctuation(); - return std::move(tokenizer); -} - -void PhraseTokenizer::Tokenize(const absl::string_view input, - std::vector* result_tokens, - std::vector* result_token_ids) { - // Word level information. - std::vector tokens; - - whitespace_tokenizer_->Tokenize(input, &tokens); - - // Loop through tokens, considering 1-level punctuations. 
- std::string all_str; - int n = tokens.size(); - for (int i = 0; i < n; i++) { - if (tokens[i].empty()) { - continue; - } - if (split_end_punctuation_) { - bool contained_special_token = false; - for (const auto& special_token : special_tokens_) { - if (absl::EndsWith(tokens[i], special_token)) { - // Eg: split "can't" into "can 't" - all_str += - tokens[i].substr(0, tokens[i].size() - special_token.size()); - all_str += " "; - all_str += special_token; - contained_special_token = true; - break; - } - } - if (!contained_special_token) { - all_str += tokens[i]; - } - } else { - all_str += tokens[i]; - } - if (i < n - 1) { - all_str += " "; - } - } - - FindPhraseTokens(all_str, result_tokens, result_token_ids); -} - -void PhraseTokenizer::FindPhraseTokens(const std::string& cur_phrase, - std::vector* phrase_tokens, - std::vector* phrase_token_ids) { - // Do a simple left to right search to tokenize the input text. - int index = 0; - while (index < cur_phrase.size()) { - bool in_trie = false; - int token_id = phrase_config_->unk_token_id(); - int length = 0; - PhraseLookup(cur_phrase, index, &in_trie, &token_id, &length); - if (!in_trie) { - // fall back to using single token. - std::size_t found = cur_phrase.find_first_of(' ', index); - phrase_tokens->push_back(phrase_config_->unk_token()->str()); - phrase_token_ids->push_back(phrase_config_->unk_token_id()); - if (found == std::string::npos) { - break; - } - index = found + 1; - } else { - // Found a phrase. 
- phrase_tokens->push_back(cur_phrase.substr(index, length)); - phrase_token_ids->push_back(token_id); - index += (length + 1); - } - } -} - -void PhraseTokenizer::PhraseLookup(const std::string& token, int cur, - bool* in_trie, int* emitted_phrase_id, - int* emitted_phrase_length) { - int matched_phrase_id = -1; - int matched_phrase_length = 0; - bool phrase_emitted = false; - float prob = prob_; - absl::BitGen* gen = &gen_; - auto phrase_emit_func = - [&token /*the input string*/, - cur /*the current starting point for searching phrase*/, - prob /*the probability to emit the current found phrase*/, - in_trie /*whether a phrase in matched in the trie*/, - emitted_phrase_id /*the token id of the emitted phrase*/, - emitted_phrase_length /*the length of the emitted phrase*/, - &matched_phrase_id /*the token id of the matched phrase*/, - &matched_phrase_length /*the length of the matched phrase*/, - &phrase_emitted /*whether the phrase is emitted or not*/, - gen /*the random generator*/]( - const sentencepiece::DoubleArrayTrie::Match& m) { - if (phrase_emitted || (cur + m.match_length < token.size() && - token[cur + m.match_length] != ' ')) { - // We should continue search without going through this function if: - // 1: a phrase has already been emitted, or - // 2: We located a phrase that split one single word. - return; - } - - matched_phrase_id = m.id; - matched_phrase_length = m.match_length; - *in_trie = true; - if ((prob > 0) && absl::Bernoulli(*gen, prob)) { - // Emit the current phrase. - *emitted_phrase_id = m.id; - *emitted_phrase_length = m.match_length; - phrase_emitted = true; - } - }; - trie_->IteratePrefixMatches( - sentencepiece::utils::string_view(token.data() + cur, token.size() - cur), - phrase_emit_func); - if (*in_trie && !phrase_emitted) { - // We should use prev longest one as output as we prefer longer ones. 
- *emitted_phrase_id = matched_phrase_id; - *emitted_phrase_length = matched_phrase_length; - } -} - -absl::StatusOr> PhraseTokenizer::DetokenizeToTokens( - const absl::Span input) const { - std::vector output_tokens; - if (!phrase_config_->support_detokenization()) { - return absl::FailedPreconditionError( - "Detokenize function is only enabled when support_detokenization is " - "true in the config flatbuffer. Please rebuild the model flatbuffer " - "by setting support_detokenization=true."); - } - for (int id : input) { - auto vocab = phrase_config_->vocab_array()->Get(id); - output_tokens.emplace_back(vocab->string_view()); - } - return output_tokens; -} - -absl::StatusOr PhraseTokenizer::Detokenize( - const absl::Span input) const { - SH_ASSIGN_OR_RETURN(std::vector output_tokens, - DetokenizeToTokens(input)); - if (split_end_punctuation_) { - std::string result; - for (const auto& token : output_tokens) { - if (special_tokens_.contains(token)) { - result += token; - } else { - result += ((result.empty() ? "" : " ") + token); - } - } - return result; - } else { - return absl::StrJoin(output_tokens, " "); - } -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/phrase_tokenizer.h b/tensorflow_text/core/kernels/phrase_tokenizer.h index b9f48eb72..f29774e7a 100644 --- a/tensorflow_text/core/kernels/phrase_tokenizer.h +++ b/tensorflow_text/core/kernels/phrase_tokenizer.h @@ -12,89 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_Phrase_TOKENIZER_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_Phrase_TOKENIZER_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_H_ -#include -#include -#include +#include "tensorflow/core/kernels/text/phrase_tokenizer.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/random/random.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "tensorflow_text/core/kernels/phrase_tokenizer_model_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie.h" -#include "tensorflow_text/core/kernels/string_vocab.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer.h" - -namespace tensorflow { -namespace text { - -class PhraseTokenizer { - public: - // Creates an instance. - // - // Args: - // * config_flatbuffer: the pointer to the PhraseTokenizerConfig - // flatbuffer, which is not owned by this instance and should be kept - // alive through the lifetime of the instance. - static absl::StatusOr Create(const void* config_flatbuffer); - - // Tokenizes a string (or series of character codepoints) by Phrase. - // - // Example: - // input = "Show me the way." - // output = ["Show me", "the", "way."] - // - // The input should be UTF-8 but the tokenization is performed on Unicode - // codepoints. - // - // Args: - // * input: The UTF-8 string of an input. - // * tokens: The output tokens. - void Tokenize(const absl::string_view input, - std::vector* result_tokens, - std::vector* result_token_ids); - - // Detokenizer the input into a single string. - absl::StatusOr Detokenize( - const absl::Span input) const; - - private: - // Detokenizer the input into vector of strings. - absl::StatusOr> DetokenizeToTokens( - const absl::Span input) const; - - // Find the phrase tokens based on the current phrase. 
- void FindPhraseTokens(const std::string& cur_phrase, - std::vector* phrase_tokens, - std::vector* phrase_token_ids); - - // Lookup the phrase in the token string from current index. - // Args: - // * token: The input token string to find the next phrase. - // * cur: The starting point to search for the phrase. - // * in_trie: Whether there is a phrase in DoubleArrayTrie. - // * emitted_phrase_id: The emitted phrase id. - // * emitted_phrase_length: The length of the emitted phrase. - void PhraseLookup(const std::string& token, int cur, bool* in_trie, - int* emitted_phrase_id, int* emitted_phrase_length); - - std::unique_ptr vocab_ = nullptr; - const PhraseTokenizerConfig* phrase_config_; - absl::string_view whitespace_config_str_; - std::unique_ptr trie_ = nullptr; - float prob_; - absl::BitGen gen_; - std::unique_ptr whitespace_tokenizer_ = nullptr; - bool split_end_punctuation_ = false; - const absl::flat_hash_set special_tokens_ = { - "'t", "'s", ".", ",", "!", "?", "'m", "'re", "'ll", "'d", "'ve"}; -}; - -} // namespace text -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_Phrase_TOKENIZER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_H_ diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_kernel.cc b/tensorflow_text/core/kernels/phrase_tokenizer_kernel.cc deleted file mode 100644 index ac47ba777..000000000 --- a/tensorflow_text/core/kernels/phrase_tokenizer_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/phrase_tokenizer_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER( - Name(PhraseTokenizeOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - PhraseTokenizeOpKernel); - -REGISTER_KERNEL_BUILDER( - Name(PhraseDetokenizeOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - PhraseDetokenizeOpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_kernel.h b/tensorflow_text/core/kernels/phrase_tokenizer_kernel.h index 302b193df..61b876333 100644 --- a/tensorflow_text/core/kernels/phrase_tokenizer_kernel.h +++ b/tensorflow_text/core/kernels/phrase_tokenizer_kernel.h @@ -15,25 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h" - -namespace tensorflow { -namespace text { - -class PhraseTokenizeOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -class PhraseDetokenizeOpKernel - : public tflite::shim::TfOpKernel { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/phrase_tokenizer_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h index 67807a768..4f27754c0 100644 --- a/tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h +++ b/tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h @@ -15,346 +15,6 @@ #ifndef 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_TEMPLATE_H_ -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/phrase_tokenizer.h" +#include "tensorflow/core/kernels/text/phrase_tokenizer_kernel_template.h" -namespace tensorflow { -namespace text { - -// See `kDoc` data member for the documentation on this op kernel. -// -// This template class can be instantiated into a kernel for either TF or -// TFLite. See -// https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/kernels/shim -// for more info on how this works. -template -class PhraseTokenizeOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { kInputValues = 0, kPhraseModel }; - enum Outputs { - kOutputSubwords = 0, - kOutputIds, - kOutputRowSplits, - }; - - using Shape = tflite::shim::Shape; - using typename tflite::shim::OpKernelShim::InitContext; - using - typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - PhraseTokenizeOp() = default; - static constexpr char kOpName[] = "PhraseTokenize"; - static constexpr char kDoc[] = R"doc( - Tokenizes tokens into phrases based off of a vocabulary. - - ### Example: - - ```python - >>> tokens = ['I have a dream', 'I like coffee'] - >>> phrase, ids, row_splits = ( - ... phrase_tokenize(tokens, model_buffer)) - >>> RaggedTensor.from_row_splits(phrase, row_splits) - [['I', 'have', 'a dream'], ['I like', 'coffee']] - >>> RaggedTensor.from_row_splits(ids, row_splits) - [[0, 1, 2], [3, 4]] # Dummy ids. - ``` - - Args: - input_values: 1D Tensor of strings to tokenize with. - phrase_model: Buffer tensor for the PhraseTokenizerConfig flatbuffer. 
- - Returns: - * output_values: 1D tensor containing the phrases for all input strings. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_ids: 1D tensor containing the phrase ids for all input strings. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_row_splits: 1D int tensor with the row splits that allow us to - build RaggedTensors from output_values, output_ids. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Input tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Output tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -std::vector PhraseTokenizeOp::Inputs() { - return {"input_values: string", "phrase_model: uint8"}; -} - -template -std::vector PhraseTokenizeOp::Outputs() { - return {"output_subwords: string", "output_ids: int64", - "output_row_splits: int64"}; -} - -template -absl::Status PhraseTokenizeOp::Invoke(InvokeContext* context) { - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto& values_vec = input_values->template As(); - - SH_ASSIGN_OR_RETURN(const auto phrase_model, context->GetInput(kPhraseModel)); - // OK to create on every call because PhraseTokenizer is a - // lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus - // Create() is very cheap. 
- auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create( - phrase_model->template Data().data()); - SH_RETURN_IF_ERROR(phrase_tokenizer.status()); - - std::vector subwords; - std::vector subword_ids; - std::vector row_splits; - - row_splits.push_back(0); - - // Iterate through all the values and wordpiece tokenize them. - for (int i = 0; i < values_vec.Dim(0); ++i) { - // Tokenize into subwords and record the offset locations. - const int original_num_wordpieces = subwords.size(); - phrase_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids); - const int delta_num_wordpieces = subwords.size() - original_num_wordpieces; - - // Record the row splits. - row_splits.push_back(delta_num_wordpieces + row_splits.back()); - } - - const int subwords_size = subwords.size(); - SH_ASSIGN_OR_RETURN( - auto output_subwords, - context->GetOutput(kOutputSubwords, Shape({subwords_size}))); - auto output_subwords_vec = - output_subwords->template As(); - - SH_ASSIGN_OR_RETURN( - auto output_ids, - context->GetOutput( - kOutputIds, - Shape({static_cast( - subword_ids.size())}))); /* same shape as `output_subwords` */ - auto output_ids_vec = output_ids->template As(); - - SH_ASSIGN_OR_RETURN( - auto output_row_splits, - context->GetOutput(kOutputRowSplits, - Shape({static_cast(row_splits.size())}))); - auto output_row_splits_vec = output_row_splits->template As(); - - for (int i = 0; i < subwords.size(); ++i) { - output_subwords_vec(i) = subwords[i]; - } - - for (int i = 0; i < subword_ids.size(); ++i) { - output_ids_vec(i) = subword_ids[i]; - } - - for (int i = 0; i < row_splits.size(); ++i) { - output_row_splits_vec(i) = row_splits[i]; - } - - return absl::OkStatus(); -} - -template -absl::Status PhraseTokenizeOp::ShapeInference(ShapeInferenceContext* c) { - using tflite::shim::Shape; - SH_ASSIGN_OR_RETURN(const Shape input_values_shape, - c->GetInputShape(kInputValues)); - SH_ASSIGN_OR_RETURN(const auto phrase_model_shape, - c->GetInputShape(kPhraseModel)); - 
const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_values_shape.ToString())); - } - if (!phrase_model_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", phrase_model_shape.ToString())); - } - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputSubwords, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputIds, rank_1_shape)); - // row splits size - const int num_splits = Shape::AddDims(1, input_values_shape.Dim(0)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, Shape({num_splits}))); - - return absl::OkStatus(); -} - -// See `kDoc` data member for the documentation on this op kernel. -// -// This template class can be instantiated into a kernel for either TF or -// TFLite. See -// https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/kernels/shim -// for more info on how this works. -template -class PhraseDetokenizeOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { kInputValues = 0, kInputRowSplits, kPhraseModel }; - enum Outputs { kOutputWords = 0 }; - - using Shape = tflite::shim::Shape; - using - typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - PhraseDetokenizeOp() = default; - static constexpr char kOpName[] = "TFText>PhraseDetokenize"; - static constexpr char kDoc[] = R"doc( - Detokenizes phrase ids into sentences. - - ### Example: - - ```python - >>> # Vocab of the model_buffer: ['I', 'have', 'a dream']. - >>> wordpiece_ids = [2, 3, 4] - >>> row_splits = [0, 2, 3] - >>> tokens = phrase_tokenizer_detokenize(tokens, row_splits, model_buffer) - >>> tokens - ['I have', 'a dream'] - ``` - - Args: - input_values: 1D Tensor of phrase ids. 
- input_row_splits: 1D Tensor of row splits that denotes the boundary of each - sentence in the `input_values`. - phrase_model: Buffer tensor for the PhraseTokenizerConfig flatbuffer. - - Returns: - * output_values: 1D tensor containing all the sentences. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs() { return {}; } - - // Input tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Output tensors declaration (syntax: - // https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -////////////////////////// Implementation - -template -std::vector PhraseDetokenizeOp::Inputs() { - return {"input_values: int32", "input_row_splits: int64", - "phrase_model: uint8"}; -} - -template -std::vector PhraseDetokenizeOp::Outputs() { - return {"output_words: string"}; -} - -template -absl::Status PhraseDetokenizeOp::Invoke(InvokeContext* context) { - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto& values_vec = input_values->template As(); - - SH_ASSIGN_OR_RETURN(const auto input_row_splits, - context->GetInput(kInputRowSplits)); - const auto& row_splits_vec = input_row_splits->template As(); - - SH_ASSIGN_OR_RETURN(const auto phrase_model, context->GetInput(kPhraseModel)); - // OK to create on every call because PhraseTokenizer is a - // lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus - // Create() is very cheap. 
- auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create( - phrase_model->template Data().data()); - SH_RETURN_IF_ERROR(phrase_tokenizer.status()); - - std::vector sentences; - - // Iterate through row_splits to split input_values. - for (int i = 0; i < row_splits_vec.Dim(0) - 1; ++i) { - auto single_input = - absl::Span(values_vec.Ptr() + row_splits_vec(i), - row_splits_vec(i + 1) - row_splits_vec(i)); - SH_ASSIGN_OR_RETURN(auto sentence, - phrase_tokenizer->Detokenize(single_input)); - sentences.push_back(sentence); - } - - SH_RETURN_IF_ERROR(this->template FillOutputTensor( - sentences, kOutputWords, context)); - - return absl::OkStatus(); -} - -template -absl::Status PhraseDetokenizeOp::ShapeInference(ShapeInferenceContext* c) { - using tflite::shim::Shape; - SH_ASSIGN_OR_RETURN(const Shape input_values_shape, - c->GetInputShape(kInputValues)); - SH_ASSIGN_OR_RETURN(const Shape input_row_splits_shape, - c->GetInputShape(kInputRowSplits)); - SH_ASSIGN_OR_RETURN(const auto phrase_model_shape, - c->GetInputShape(kPhraseModel)); - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_values_shape.ToString())); - } - if (!input_row_splits_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError(absl::StrCat( - "Shape must be rank 1: ", input_row_splits_shape.ToString())); - } - if (!phrase_model_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", phrase_model_shape.ToString())); - } - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputWords, rank_1_shape)); - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_phrase_TOKENIZER_KERNEL_TEMPLATE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_TEMPLATE_H_ diff --git 
a/tensorflow_text/core/kernels/phrase_tokenizer_model.fbs b/tensorflow_text/core/kernels/phrase_tokenizer_model.fbs deleted file mode 100644 index d2d891cba..000000000 --- a/tensorflow_text/core/kernels/phrase_tokenizer_model.fbs +++ /dev/null @@ -1,38 +0,0 @@ -namespace tensorflow.text; - -table Trie { - nodes: [uint32]; -} - - -table PhraseTokenizerConfig { - // Probability of emitting a phrase when there is a match. - // The larger value means preferring shorter phrases over longer ones. - // I.e. 0 means always emit the longest possible phrase. - prob: int; - - // The unknown token string. - unk_token: string; - - // The unkown token id. - unk_token_id: int; - - // Whether the tokenizer supports detokenization function. - support_detokenization: bool; - - // Phrases Vocabulary array, this is for storting the phrase tokens in order, - // mainly used for detokenization. - vocab_array: [string]; - - // The trie is used to construct DoubleArrayTrie to do efficient prefix - // matching during tokenization. - vocab_trie: Trie; - - // whilte space config used to initalize the whitespace tokenzier. - whitespace_config: string; - - // Whether to split the end_puctualtion for each token. - split_end_punctuation: bool; -} - -root_type PhraseTokenizerConfig; diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.cc b/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.cc deleted file mode 100644 index 268aeb32c..000000000 --- a/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.cc +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h" - -#include - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/phrase_tokenizer_model_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h" -#include "tensorflow_text/core/kernels/string_vocab.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" - -namespace tensorflow { -namespace text { -namespace { - -// Builds the PhraseTokenizer model. -class PhraseBuilder { - public: - absl::Status BuildModel(const std::vector& vocab, - const std::string& unk_token, - bool support_detokenization, int prob, - bool split_end_punctuation); - - absl::StatusOr ExportToFlatBuffer() const; - - private: - std::unique_ptr vocab_; - std::vector trie_data_; - std::string unk_token_; - int unk_token_id_; - // Whether the tokenizer supports the detokenization function. 
- bool support_detokenization_; - int prob_; - bool split_end_punctuation_; -}; - -absl::Status PhraseBuilder::BuildModel(const std::vector& vocab, - const std::string& unk_token, - bool support_detokenization, int prob, - bool split_end_punctuation) { - unk_token_ = std::string(unk_token); - support_detokenization_ = support_detokenization; - prob_ = prob; - split_end_punctuation_ = split_end_punctuation; - - vocab_ = std::make_unique(vocab); - if (vocab_->Size() != vocab.size()) { - return absl::FailedPreconditionError( - "Tokens in the vocabulary must be unique."); - } - - // Determine `unk_token_id_`. - const absl::optional unk_token_id = vocab_->LookupId(unk_token_); - if (!unk_token_id.has_value()) { - return absl::FailedPreconditionError("Cannot find unk_token in the vocab!"); - } - unk_token_id_ = *unk_token_id; - - // build trie. - trie_data_ = sentencepiece::BuildTrie(vocab); - - return absl::OkStatus(); -} - -absl::StatusOr PhraseBuilder::ExportToFlatBuffer() const { - flatbuffers::FlatBufferBuilder builder; - - const auto unk_token = builder.CreateString(unk_token_); - - std::vector> vocab_fbs_vector; - - if (support_detokenization_) { - vocab_fbs_vector.reserve(vocab_->Size()); - for (int i = 0; i < vocab_->Size(); ++i) { - const absl::optional word = vocab_->LookupWord(i); - if (!word.has_value()) { - return absl::FailedPreconditionError( - "Impossible. 
`token_id` is definitely within the range of vocab " - "token ids; hence LookupWord() should always succeed."); - } - absl::string_view token = word.value(); - vocab_fbs_vector.emplace_back(builder.CreateString(token)); - } - } - - auto vocab_array = builder.CreateVector(vocab_fbs_vector); - - std::string ws_config = BuildWhitespaceTokenizerConfig(); - auto whitespace_config = builder.CreateString(ws_config); - auto trie_data = builder.CreateVector(trie_data_); - - TrieBuilder trie_builder(builder); - trie_builder.add_nodes(trie_data); - const auto trie_fbs = trie_builder.Finish(); - - PhraseTokenizerConfigBuilder wtcb(builder); - wtcb.add_unk_token(unk_token); - wtcb.add_unk_token_id(unk_token_id_); - wtcb.add_support_detokenization(support_detokenization_); - wtcb.add_vocab_array(vocab_array); - wtcb.add_whitespace_config(whitespace_config); - wtcb.add_vocab_trie(trie_fbs); - wtcb.add_prob(prob_); - wtcb.add_split_end_punctuation(split_end_punctuation_); - FinishPhraseTokenizerConfigBuffer(builder, wtcb.Finish()); - return std::string(reinterpret_cast(builder.GetBufferPointer()), - builder.GetSize()); -} -} // namespace - -absl::StatusOr BuildPhraseModelAndExportToFlatBuffer( - const std::vector& vocab, const std::string& unk_token, - bool support_detokenization, int prob, bool split_end_punctuation) { - PhraseBuilder builder; - SH_RETURN_IF_ERROR(builder.BuildModel( - vocab, unk_token, support_detokenization, prob, split_end_punctuation)); - SH_ASSIGN_OR_RETURN(std::string flatbuffer, builder.ExportToFlatBuffer()); - return flatbuffer; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h b/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h index 86cd35b20..2f89b19e7 100644 --- a/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h +++ b/tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h @@ -15,30 +15,6 @@ #ifndef 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_BUILDER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_BUILDER_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { - -// Builds a PhraseTokenizer model in flatbuffer format. -// -// Args: -// * vocab: The phrase vocabulary. -// * unk_token: The unknown token string. -//. * support_detokenization: Whether to enable the detokenization function. -// Setting it to true expands the size of the flatbuffer. -// * prob: Probability of emitting a phrase when there is a match. -// Returns: -// The bytes of the flatbuffer that stores the model. -absl::StatusOr BuildPhraseModelAndExportToFlatBuffer( - const std::vector& vocab, const std::string& unk_token, - bool support_detokenization = false, int prob = 0, - bool split_end_punctuation = false); -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/phrase_tokenizer_model_builder.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/utf8_binarize_kernel.cc b/tensorflow_text/core/kernels/phrase_tokenizer_model_generated.h similarity index 64% rename from tensorflow_text/core/kernels/utf8_binarize_kernel.cc rename to tensorflow_text/core/kernels/phrase_tokenizer_model_generated.h index 80049896f..8b7067bef 100644 --- a/tensorflow_text/core/kernels/utf8_binarize_kernel.cc +++ b/tensorflow_text/core/kernels/phrase_tokenizer_model_generated.h @@ -12,16 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "tensorflow_text/core/kernels/utf8_binarize_kernel.h" +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_GENERATED_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_GENERATED_H_ -#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/text/phrase_tokenizer_model_generated.h" -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER( - Name(Utf8BinarizeOpKernel::OpName()).Device(tensorflow::DEVICE_CPU), - Utf8BinarizeOpKernel); - -} // namespace text -} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_MODEL_GENERATED_H_ diff --git a/tensorflow_text/core/kernels/phrase_tokenizer_test.cc b/tensorflow_text/core/kernels/phrase_tokenizer_test.cc deleted file mode 100644 index e8ae06570..000000000 --- a/tensorflow_text/core/kernels/phrase_tokenizer_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/phrase_tokenizer.h" - -#include -#include -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "tensorflow/core/platform/env.h" - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::ElementsAre; - -/* With the following vocab: - -I -heard -the -news -today -have -heard news today -the news today -*/ -constexpr char kTestConfigPath[] = - "tensorflow_text/python/ops/test_data/" - "phrase_tokenizer_model.fb"; - -TEST(PhraseTokenizerTest, Tokenize) { - absl::string_view input("I heard the news today"); - std::vector output_tokens; - std::vector output_token_ids; - - std::string config_flatbuffer; - auto status = tensorflow::ReadFileToString( - tensorflow::Env::Default(), kTestConfigPath, &config_flatbuffer); - ASSERT_TRUE(status.ok()); - - ASSERT_OK_AND_ASSIGN(auto tokenizer, - PhraseTokenizer::Create(config_flatbuffer.data())); - - tokenizer.Tokenize(input, &output_tokens, &output_token_ids); - EXPECT_THAT(output_tokens, ElementsAre("I", "heard", "the news today")); - EXPECT_THAT(output_token_ids, ElementsAre(1, 2, 8)); -} - -TEST(PhraseTokenizerTest, TokenizeLonger) { - absl::string_view input("I heard the news today I heard"); - std::vector output_tokens; - std::vector output_token_ids; - - std::string config_flatbuffer; - auto status = tensorflow::ReadFileToString( - tensorflow::Env::Default(), kTestConfigPath, &config_flatbuffer); - ASSERT_TRUE(status.ok()); - - ASSERT_OK_AND_ASSIGN(auto tokenizer, - PhraseTokenizer::Create(config_flatbuffer.data())); - - tokenizer.Tokenize(input, &output_tokens, &output_token_ids); - EXPECT_THAT(output_tokens, - ElementsAre("I", "heard", "the news today", "I", "heard")); - EXPECT_THAT(output_token_ids, ElementsAre(1, 2, 8, 1, 2)); -} - -TEST(PhraseTokenizerTest, DeTokenize) { - std::vector input({1, 2, 8}); - - std::string config_flatbuffer; - auto status = tensorflow::ReadFileToString( - tensorflow::Env::Default(), kTestConfigPath, 
&config_flatbuffer); - ASSERT_TRUE(status.ok()); - - ASSERT_OK_AND_ASSIGN(auto tokenizer, - PhraseTokenizer::Create(config_flatbuffer.data())); - - auto output_string = tokenizer.Detokenize(input); - EXPECT_EQ(output_string.value(), "I heard the news today"); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.cc b/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.cc deleted file mode 100644 index 977cf5836..000000000 --- a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.cc +++ /dev/null @@ -1,745 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "flatbuffers/flexbuffers.h" -#include "tensorflow/core/util/ragged_to_dense_util_common.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" -#include "tensorflow/lite/kernels/internal/types.h" -#include "tensorflow/lite/kernels/kernel_util.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -namespace ragged_tensor_to_tensor { -namespace { - -constexpr int kShapeInput = 0; -constexpr int kValuesInput = 1; -constexpr int kDefaultValueInput = 2; -constexpr int kFirstPartitionInputIndex = 3; - -constexpr int kOutputTensor = 0; - -constexpr char kRowPartitionTypesAttr[] = "row_partition_types"; - -// The following three functions are copied from -// .../tensorflow/lite/kernels/internal/tensor_ctypes.h -// This header is not available in tensorflow package when building. -template -inline T* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? reinterpret_cast(tensor->data.raw) : nullptr; -} - -template -inline const T* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? 
reinterpret_cast(tensor->data.raw) - : nullptr; -} - -inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return RuntimeShape(); - } - - TfLiteIntArray* dims = tensor->dims; - const int dims_size = dims->size; - const int32_t* dims_data = reinterpret_cast(dims->data); - return RuntimeShape(dims_size, dims_data); -} - -struct ConversionAttributes { - std::vector partition_types; - int ragged_rank = 0; - - tensorflow::RowPartitionType GetRowPartitionTypeByDimension( - int dimension) const { - if (partition_types.front() == - tensorflow::RowPartitionType::FIRST_DIM_SIZE) { - return partition_types[dimension + 1]; - } else { - return partition_types[dimension]; - } - } -}; -template -int GetFirstDimensionSizeT(TfLiteContext* context, - const TfLiteTensor& first_partition_input, - const ConversionAttributes* attributes) { - const tensorflow::RowPartitionType first_partition_type = - attributes->partition_types.front(); - switch (first_partition_type) { - case tensorflow::RowPartitionType::FIRST_DIM_SIZE: - return *GetTensorData(&first_partition_input); - case tensorflow::RowPartitionType::VALUE_ROWIDS: - context->ReportError(context, - "Cannot handle VALUE_ROWIDS in first dimension."); - return -1; - case tensorflow::RowPartitionType::ROW_SPLITS: { - const auto shape = GetTensorShape(&first_partition_input); - return shape.Dims(0) - 1; - } - - default: - context->ReportError( - context, "Cannot handle type ", - RowPartitionTypeToString(first_partition_type).c_str()); - return -1; - } -} - -int GetFirstDimensionSize(TfLiteContext* context, - const TfLiteTensor& first_partition_input, - const ConversionAttributes* attributes) { - switch (first_partition_input.type) { - case kTfLiteInt32: - return GetFirstDimensionSizeT(context, first_partition_input, - attributes); - case kTfLiteInt64: - return GetFirstDimensionSizeT(context, first_partition_input, - attributes); - default: - context->ReportError(context, - "Not supported row 
partitioning tensor type"); - return -1; - } -} - -bool ValidateDefaultValueShape(TfLiteContext* context, - const RuntimeShape& default_value_shape, - const RuntimeShape& /*value_shape*/) { - // TF implementation also checks that shapes are not defined, not needed in - // TFLite. - // TODO(mgubin): Only scalar default value sizes are supported. - if (default_value_shape.FlatSize() != 1) { - context->ReportError(context, "Only scalar default value is supported"); - return false; - } - return true; -} - -RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) { - // TODO(mgubin): No checks, see - // third_party/tensorflow/core/kernels/list_kernels.cc - const RuntimeShape tensor_shape(tensor.dims->size, tensor.dims->data); - if (0 == tensor.dims->size) { - // If the input tensor is scalar then the shape is empty (also scalar). - return RuntimeShape{}; - } - RuntimeShape result(tensor_shape.FlatSize()); - switch (tensor.type) { - case kTfLiteInt32: { - for (int i = 0; i < tensor_shape.FlatSize(); ++i) { - result.SetDim(i, GetTensorData(&tensor)[i]); - } - } break; - case kTfLiteInt64: { - for (int i = 0; i < tensor_shape.FlatSize(); ++i) { - result.SetDim(i, GetTensorData(&tensor)[i]); - } - } break; - default: { - // Checked in Prepare. 
- } - } - return result; -} - -const TfLiteTensor* GetRowPartitionTensor( - const ConversionAttributes& conversion_attributes, TfLiteContext* context, - TfLiteNode* node, int dimension) { - if (conversion_attributes.partition_types.front() == - tensorflow::RowPartitionType::FIRST_DIM_SIZE) { - return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 + - dimension]]; - } else { - return &context->tensors[node->inputs - ->data[kFirstPartitionInputIndex + dimension]]; - } -} - -int GetMaxWidthValueRowID(const TfLiteTensor* tensor) { - const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data); - const int index_length = tensor_shape.FlatSize(); - if (index_length == 0) { - return 0; - } - auto value_rowids = [tensor](int index) { - switch (tensor->type) { - case kTfLiteInt32: - return static_cast(tensor->data.i32[index]); - case kTfLiteInt64: - return static_cast(tensor->data.i64[index]); - default: - // TODO(mgubin): Add error checks. - return 0; - } - }; - int first_equal_index = 0; - int first_equal_index_value = value_rowids(0); - int max_width = 0; - for (int i = 0; i < index_length; ++i) { - const int value = value_rowids(i); - if (value != first_equal_index_value) { - first_equal_index_value = value; - max_width = std::max(i - first_equal_index, max_width); - first_equal_index = i; - } - } - return std::max(index_length - first_equal_index, max_width); -} - -int GetMaxWidthRowSplit(const TfLiteTensor* tensor) { - const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data); - const int tensor_length = tensor_shape.FlatSize(); - if (tensor_length == 0 || tensor_length == 1) { - return 0; - } - auto value_rowsplit = [tensor](int index) { - switch (tensor->type) { - case kTfLiteInt32: - return static_cast(tensor->data.i32[index]); - case kTfLiteInt64: - return static_cast(tensor->data.i64[index]); - default: - // TODO(mgubin): Add error checks. 
- return 0; - } - }; - int max_width = 1; - int prev_split = value_rowsplit(0); - for (int i = 1; i < tensor_length; ++i) { - const int split = value_rowsplit(i); - max_width = std::max(max_width, split - prev_split); - prev_split = split; - } - return max_width; -} - -int GetMaxWidth(const ConversionAttributes& conversion_attributes, - TfLiteContext* context, TfLiteNode* node, int dimension) { - const TfLiteTensor* tensor = GetRowPartitionTensor( - conversion_attributes, context, node, dimension - 1); - switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) { - case tensorflow::RowPartitionType::VALUE_ROWIDS: - return GetMaxWidthValueRowID(tensor); - case tensorflow::RowPartitionType::ROW_SPLITS: - return GetMaxWidthRowSplit(tensor); - default: - context->ReportError(context, "Cannot handle partition type"); - return -1; - } -} - -RuntimeShape CombineRaggedTensorToTensorShapes( - int ragged_rank, const RuntimeShape& output_shape, - const RuntimeShape& value_shape) { - // TODO(mgubin): No checks, see - // third_party/tensorflow/core/ops/ragged_to_dense_util.cc - RuntimeShape result(output_shape); - if (output_shape.DimensionsCount() == 0) { - const int output_shape_rank = ragged_rank + value_shape.DimensionsCount(); - result.Resize(output_shape_rank); - for (int i = 0; i < output_shape_rank; ++i) { - result.SetDim(i, -1); - } - } - const int need_to_set = - output_shape.DimensionsCount() - value_shape.DimensionsCount(); - for (int i = 1; i < value_shape.DimensionsCount(); ++i) { - result.SetDim(need_to_set + i, value_shape.Dims(i)); - } - return result; -} - -RuntimeShape CalculateOutputSize( - const ConversionAttributes& conversion_attributes, TfLiteContext* context, - TfLiteNode* node, int first_dimension, int ragged_rank, - const TfLiteTensor& values, const TfLiteTensor& default_value, - const TfLiteTensor& output_shape) { - RuntimeShape values_shape(values.dims->size, values.dims->data); - RuntimeShape 
default_value_shape(default_value.dims->size, - default_value.dims->data); - - if (!ValidateDefaultValueShape(context, default_value_shape, values_shape)) { - return {}; - } - RuntimeShape output_shape_shape = TensorShapeFromTensor(output_shape); - - RuntimeShape result_shape = CombineRaggedTensorToTensorShapes( - ragged_rank, output_shape_shape, values_shape); - if (result_shape.Dims(0) < 0) { - result_shape.SetDim(0, first_dimension); - } - for (int i = 1; i <= ragged_rank; ++i) { - if (result_shape.Dims(i) < 0) { - result_shape.SetDim(i, - GetMaxWidth(conversion_attributes, context, node, i)); - } - } - return result_shape; -} - -TfLiteIntArray* IntArrayFromShape(const RuntimeShape& shape) { - TfLiteIntArray* result = TfLiteIntArrayCreate(shape.DimensionsCount()); - for (int i = 0; i < shape.DimensionsCount(); ++i) { - result->data[i] = shape.Dims(i); - } - return result; -} - -/** - * The output_index represents the index in the output tensor - * where the first element of a particular dimension would be written. - * If it is -1, it indicates that the index is out of scope. - * Example, given first_dimension = 10, first_dimension_output = 6, - * and output_index_multiplier = 100: - * result = [0 100 200 300 400 500 -1 -1 -1 -1] - * If first_dimension_output = 11 instead, then: - * result = [0 100 200 300 400 500 600 700 800 900] - */ -void CalculateFirstParentOutputIndex(int first_dimension, - int output_index_multiplier, - int first_dimension_output, - std::vector* result) { - const int min_dimension = std::min(first_dimension, first_dimension_output); - result->reserve(first_dimension); - int current_output_index = 0; - for (int i = 0; i < min_dimension; - ++i, current_output_index += output_index_multiplier) { - result->push_back(current_output_index); - } - for (int i = min_dimension; i < first_dimension; ++i) { - result->push_back(-1); - } -} -// Calculate the output index of the first element of a list. 
-// The parent_output_index is the same computation for the previous list. -// -1 indicates an element or list that is out of range. -// The output_index_multiplier is the number of output indices one moves -// forward for each column. -// E.g., given: -// value_rowids:[0 1 2 2 2 3 5 5 6] -// parent_output_index:[1000 1100 2000 2100 -1 3000 4000] -// output_index_multiplier: 10 -// output_size: 2 -// You get: -// result = [1000 1100 2000 2010 -1 2100 -1 -1 3000] -// result[0] = parent_output_index[value_rowids[0]] -// result[1] = parent_output_index[value_rowids[1]] -// result[2] = parent_output_index[value_rowids[2]] -// result[3] = parent_output_index[value_rowids[2] + 10] -// result[4] = -1 because it is the third element the size is 2. -// result[5] = parent_output_index[value_rowids[3]] -// result[6] = -1 because parent_output_index[value_rowids[6]] == -1 -// result[7] = -1 because parent_output_index[value_rowids[6]] == -1 -// result[8] = parent_output_index[value_rowids[7]] -void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, - const std::vector& parent_output_index, - int output_index_multiplier, - int output_size, std::vector* result) { - const RuntimeShape tensor_shape(value_rowids.dims->size, - value_rowids.dims->data); - const int index_size = tensor_shape.FlatSize(); - result->reserve(index_size); - if (index_size == 0) { - return; - } - - auto value_rowids_val = [value_rowids](int index) { - switch (value_rowids.type) { - case kTfLiteInt32: - return static_cast(value_rowids.data.i32[index]); - case kTfLiteInt64: - return static_cast(value_rowids.data.i64[index]); - default: - // TODO(mgubin): Add error checks. 
- return 0; - } - }; - int current_output_column = 0; - int current_value_rowid = value_rowids_val(0); - // DCHECK_LT(current_value_rowid, parent_output_index.size()); - int current_output_index = parent_output_index[current_value_rowid]; - result->push_back(current_output_index); - for (int i = 1; i < index_size; ++i) { - int next_value_rowid = value_rowids_val(i); - if (next_value_rowid == current_value_rowid) { - if (current_output_index >= 0) { - ++current_output_column; - if (current_output_column < output_size) { - current_output_index += output_index_multiplier; - } else { - current_output_index = -1; - } - } - } else { - current_output_column = 0; - current_value_rowid = next_value_rowid; - // DCHECK_LT(next_value_rowid, parent_output_index.size()); - current_output_index = parent_output_index[next_value_rowid]; - } - result->push_back(current_output_index); - } - // DCHECK_EQ(result->size(), value_rowids.size()); -} - -void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, - const std::vector& parent_output_index, - int output_index_multiplier, int output_size, - std::vector* result) { - const RuntimeShape row_split_shape(row_split.dims->size, - row_split.dims->data); - const int row_split_size = row_split_shape.FlatSize(); - auto row_split_val = [row_split](int index) { - switch (row_split.type) { - case kTfLiteInt32: - return static_cast(row_split.data.i32[index]); - case kTfLiteInt64: - return static_cast(row_split.data.i64[index]); - default: - // TODO(mgubin): Add error checks. 
- return 0; - } - }; - if (row_split_size > 0) { - result->reserve(row_split_val(row_split_size - 1)); - } - for (int i = 0; i < row_split_size - 1; ++i) { - const int row_length = row_split_val(i + 1) - row_split_val(i); - int real_length = std::min(output_size, row_length); - int parent_output_index_current = parent_output_index[i]; - - if (parent_output_index_current == -1) { - real_length = 0; - } - for (int j = 0; j < real_length; ++j) { - result->push_back(parent_output_index_current); - parent_output_index_current += output_index_multiplier; - } - for (int j = 0; j < row_length - real_length; ++j) { - result->push_back(-1); - } - } - // if (row_split_size > 0) { - // DCHECK_EQ(result->size(), row_split(row_split_size - 1)); - //} -} - -TfLiteStatus CalculateOutputIndex( - const ConversionAttributes& conversion_attributes, TfLiteContext* context, - TfLiteNode* node, int dimension, - const std::vector& parent_output_index, int output_index_multiplier, - int output_size, std::vector* result) { - const TfLiteTensor* row_partition_tensor = - GetRowPartitionTensor(conversion_attributes, context, node, dimension); - auto partition_type = - conversion_attributes.GetRowPartitionTypeByDimension(dimension); - switch (partition_type) { - case tensorflow::RowPartitionType::VALUE_ROWIDS: - CalculateOutputIndexValueRowID(*row_partition_tensor, parent_output_index, - output_index_multiplier, output_size, - result); - return kTfLiteOk; - case tensorflow::RowPartitionType::ROW_SPLITS: - CalculateOutputIndexRowSplit(*row_partition_tensor, parent_output_index, - output_index_multiplier, output_size, - result); - return kTfLiteOk; - default: - context->ReportError(context, "Unsupported partition type"); - return kTfLiteError; - } -} - -template -void SetOutputT(TfLiteContext* context, int ragged_rank, - const std::vector& output_index, - const TfLiteTensor& values_tensor, - const TfLiteTensor& default_value_tensor, - TfLiteTensor* output_tensor) { - const VALUE_TYPE* values_base 
= GetTensorData(&values_tensor); - VALUE_TYPE* output_base = GetTensorData(output_tensor); - const VALUE_TYPE* default_value = - GetTensorData(&default_value_tensor); - - RuntimeShape output_shape = GetTensorShape(output_tensor); - RuntimeShape element_shape = - RuntimeShape(output_shape.DimensionsCount() - ragged_rank - 1, - output_shape.DimsData() + ragged_rank + 1); - - // element_shape.RemoveDimRange(0, ragged_rank + 1); - const int value_element_size = element_shape.FlatSize(); - size_t output_index_size = output_index.size(); - - // Loop through the output_index vector, finding contiguous regions that - // should be copied. Once we find the end of a contiguous region, copy it - // and add any necessary padding (with default_value). - int src_start = 0; // Start of contiguous region (in values) - int dst_start = 0; // Destination for contiguous region (in output) - int dst_end = 0; // Destination for contiguous region (in output) - for (int src_i = 0; src_i <= output_index_size; ++src_i) { - // dst_i is the destination where the value at src_i should be copied. - int dst_i = src_i < output_index_size ? output_index[src_i] : -1; - - // If we're still in a contiguous region, then update dst_end go to the - // next src_i. - if (dst_i == dst_end) { - ++dst_end; - continue; - } - - // We found the end of contiguous region. This can be because we found - // a gap (dst_i > dst_end), or a source value that shouldn't be copied - // because it's out-of-bounds (dst_i == -1), or the end of the tensor - // (dst_i = -1). - if (dst_start < dst_end) { - // Copy the contiguous region. - const VALUE_TYPE* src = values_base + src_start * value_element_size; - VALUE_TYPE* dst = output_base + dst_start * value_element_size; - int nvals = (dst_end - dst_start) * value_element_size; - std::copy(src, src + nvals, dst); - // copy_array(dst, src, nvals); - } - - // Add any necessary padding (w/ default_value). 
- if (src_i >= output_index_size) { - // We reached the end of values: pad to the end of output. - const int output_size = output_shape.FlatSize(); - dst_i = output_size / value_element_size; - } - if (dst_i > dst_end) { - std::fill(output_base + dst_end * value_element_size, - output_base + dst_i * value_element_size, *default_value); - dst_end = dst_i; - } - - // Update indices. - if (dst_i < 0) { - // src_i should be skipped -- leave it out of the contiguous region. - src_start = src_i + 1; - dst_start = dst_end; - } else { - // src_i should be copied -- include it in the contiguous region. - src_start = src_i; - dst_start = dst_end; - dst_end = dst_start + 1; - } - } -} - -bool IsSupportedTensorType(TfLiteType type) { - // Should reflect SetOutput capabilities. - return type == kTfLiteInt32 || type == kTfLiteInt64 || type == kTfLiteFloat32; -} - -TfLiteStatus SetOutput(TfLiteContext* context, int ragged_rank, - const std::vector& output_index, - const TfLiteTensor& values_tensor, - const TfLiteTensor& default_value_tensor, - TfLiteTensor* output_tensor) { - switch (output_tensor->type) { - case kTfLiteInt32: - SetOutputT(context, ragged_rank, output_index, values_tensor, - default_value_tensor, output_tensor); - return kTfLiteOk; - case kTfLiteInt64: - SetOutputT(context, ragged_rank, output_index, values_tensor, - default_value_tensor, output_tensor); - return kTfLiteOk; - case kTfLiteFloat32: - SetOutputT(context, ragged_rank, output_index, values_tensor, - default_value_tensor, output_tensor); - return kTfLiteOk; - default: - // Should not happen, checked in Prepare. - // Left as a defensive programming artifact for future updates. 
- context->ReportError(context, "Not supported values type"); - return kTfLiteError; - } -} - -} // namespace - -void* Initialize(TfLiteContext* context, const char* buffer, size_t length) { - auto attributes = std::make_unique(); - - const uint8_t* buffer_t = reinterpret_cast(buffer); - - const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); - // TODO (mgubin): Converting flat buffer to a vector of strings looks not very - // effective but simple. A cleaner way is needed. - const flexbuffers::TypedVector row_partition_types_attr = - m[kRowPartitionTypesAttr].AsTypedVector(); - std::vector row_partition_types_attr_strings; - row_partition_types_attr_strings.reserve(row_partition_types_attr.size()); - for (int i = 0; i < row_partition_types_attr.size(); ++i) { - row_partition_types_attr_strings.emplace_back( - row_partition_types_attr[i].AsString().str()); - } - attributes->partition_types = - tensorflow::GetRowPartitionTypesHelper(row_partition_types_attr_strings); - if (attributes->partition_types.size() != - row_partition_types_attr_strings.size()) { - context->ReportError(context, "Can't parse partition type attribute"); - return nullptr; - } - attributes->ragged_rank = - tensorflow::GetRaggedRank(attributes->partition_types); - return attributes.release(); -} -void Free(TfLiteContext* /*context*/, void* buffer) { - ConversionAttributes* attributes = - reinterpret_cast(buffer); - delete attributes; -} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - const ConversionAttributes* attributes = - reinterpret_cast(node->user_data); - if (attributes == nullptr) { - // Parsing attributes failed, can't prepare. 
- context->ReportError(context, "Attributes are not initialized"); - return kTfLiteError; - } - TfLiteTensor& output_tensor = - context->tensors[node->outputs->data[kOutputTensor]]; - if (!IsSupportedTensorType(output_tensor.type)) { - context->ReportError(context, "Unsupported ragged tensor type"); - return kTfLiteError; - } - // The output tensor needs to be set to dynamic because it can have different - // size. - SetTensorToDynamic(&output_tensor); - - // Check that input shape tensor is int32 or int64 - TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]]; - if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64) { - context->ReportError(context, - "Input shape tensor could be only int32 or int64"); - return kTfLiteError; - } - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const ConversionAttributes* attributes = - reinterpret_cast(node->user_data); - TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]]; - TfLiteTensor& input_values = - context->tensors[node->inputs->data[kValuesInput]]; - TfLiteTensor& default_value = - context->tensors[node->inputs->data[kDefaultValueInput]]; - // TODO (mgubin): Only scallar default value is supported. - if (RuntimeShape(default_value.dims->size, default_value.dims->data) - .FlatSize() != 1) { - context->ReportError(context, "Only scallar default value is supported"); - return kTfLiteError; - } - TfLiteTensor& first_partition_input = - context->tensors[node->inputs->data[kFirstPartitionInputIndex]]; - - // Calculate dimensions. 
- const int first_dimension = - GetFirstDimensionSize(context, first_partition_input, attributes); - if (first_dimension < 0) { - return kTfLiteError; - } - RuntimeShape output_shape = CalculateOutputSize( - *attributes, context, node, first_dimension, attributes->ragged_rank, - input_values, default_value, input_shape); - if (output_shape.DimensionsCount() == 0) { - return kTfLiteError; - } - - std::vector multiplier; - multiplier.resize(attributes->ragged_rank + 1); - multiplier.back() = 1; - for (int i = multiplier.size() - 2; i >= 0; --i) { - multiplier[i] = multiplier[i + 1] * output_shape.Dims(i + 1); - } - - // Allocate output tensor. - TfLiteTensor& output_tensor = - context->tensors[node->outputs->data[kOutputTensor]]; - - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, &output_tensor, - IntArrayFromShape(output_shape))); - - // Copy data. - const int full_size = multiplier.front() * output_shape.Dims(0); - if (full_size > 0) { - std::vector output_index, new_output_index; - int nvals = input_values.dims->data[0]; - output_index.reserve(nvals); - new_output_index.reserve(nvals); - - CalculateFirstParentOutputIndex(first_dimension, multiplier[0], - output_shape.Dims(0), &output_index); - for (int i = 1; i <= attributes->ragged_rank; ++i) { - TF_LITE_ENSURE_OK( - context, CalculateOutputIndex( - *attributes, context, node, i - 1, output_index, - multiplier[i], output_shape.Dims(i), &new_output_index)); - output_index.swap(new_output_index); - new_output_index.clear(); - } - - TF_LITE_ENSURE_OK(context, - SetOutput(context, attributes->ragged_rank, output_index, - input_values, default_value, &output_tensor)); - } - return kTfLiteOk; -} - -static TfLiteRegistration* GetTfLiteRegistration() { - static TfLiteRegistration r = {Initialize, Free, Prepare, Eval}; - return &r; -} - -} // namespace ragged_tensor_to_tensor - -extern "C" void AddRaggedTensorToTensor(tflite::MutableOpResolver* resolver) { - resolver->AddCustom("RaggedTensorToTensor", - 
ragged_tensor_to_tensor::GetTfLiteRegistration()); -} - -TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR() { - return ragged_tensor_to_tensor::GetTfLiteRegistration(); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h b/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h index 30586a094..bf587ca35 100644 --- a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h +++ b/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h @@ -15,19 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_RAGGED_TENSOR_TO_TENSOR_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_RAGGED_TENSOR_TO_TENSOR_TFLITE_H_ -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddRaggedTensorToTensor(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/ragged_tensor_to_tensor_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_RAGGED_TENSOR_TO_TENSOR_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite_test.cc b/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite_test.cc deleted file mode 100644 index 347f47126..000000000 --- a/tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite_test.cc +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include -#include -#include "flatbuffers/flexbuffers.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/internal/tensor.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR(); -} // namespace text -} // namespace custom -} // namespace ops - -namespace { - -class RaggedTensorToTensorOpModel : public SingleOpModel { - public: - RaggedTensorToTensorOpModel(int output_shape_dims, - std::initializer_list values_shape, - std::initializer_list> - partition_tensors_shapes, - std::vector partition_types, - TensorType value_type = TensorType_FLOAT32, - TensorType index_type = TensorType_INT32, - bool allocate_and_delegate = true) { - // A structure to collect shapes for the input. 
- std::vector> shapes; - input_shape_ = AddInput(index_type); - shapes.push_back({output_shape_dims}); - input_values_ = AddInput(value_type); - shapes.emplace_back(values_shape); - input_default_values_ = AddInput(value_type); - shapes.push_back({1}); - for (const auto& p : partition_tensors_shapes) { - partition_tensors_.push_back(AddInput(TensorType_INT32)); - shapes.emplace_back(p); - } - output_ = AddOutput(value_type); - - flexbuffers::Builder fbb; - size_t start = fbb.StartMap(); - { - size_t start = fbb.StartVector("row_partition_types"); - for (const auto& s : partition_types) { - fbb.String(s); - } - fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); - } - fbb.Int("num_row_partition_tensors", partition_types.size()); - fbb.EndMap(start); - fbb.Finish(); - SetCustomOp("RaggedTensorToTensor", fbb.GetBuffer(), - ops::custom::text::Register_RAGGED_TENSOR_TO_TENSOR); - BuildInterpreter(shapes, /*num_threads=*/-1, - /*allow_fp32_relax_to_fp16=*/false, - /*apply_delegate=*/true, - /*allocate_and_delegate=*/allocate_and_delegate); - } - - std::vector GetOutputShape() { return GetTensorShape(output_); } - - std::vector GetOutputFloat() { return ExtractVector(output_); } - std::vector GetOutputInt() { - return ExtractVector(output_); - } - - void InvokeFloat(const std::vector& shape, - const std::vector& values, float default_value, - const std::vector>& partition_values) { - PopulateTensor(input_shape_, shape); - PopulateTensor(input_values_, values); - PopulateTensor(input_default_values_, {default_value}); - for (int i = 0; i < partition_values.size(); ++i) { - PopulateTensor(partition_tensors_[i], partition_values[i]); - } - SingleOpModel::Invoke(); - } - void InvokeInt(const std::vector& shape, - const std::vector& values, int32_t default_value, - const std::vector>& partition_values) { - PopulateTensor(input_shape_, shape); - PopulateTensor(input_values_, values); - PopulateTensor(input_default_values_, {default_value}); - for (int i = 0; i < 
partition_values.size(); ++i) { - PopulateTensor(partition_tensors_[i], partition_values[i]); - } - SingleOpModel::Invoke(); - } - TfLiteStatus TryAllocateTensors() { return interpreter_->AllocateTensors(); } - - private: - int input_shape_; - int input_values_; - int input_default_values_; - std::vector partition_tensors_; - int output_; -}; - -TEST(RaggedTensorToTensorTest, RaggedTensorToTensor) { - // indices = [2, 1, 0, 3] - // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] - // params.shape = [4, None] - RaggedTensorToTensorOpModel model( - 2, // output_shape_dims - {9}, // values_shape - {{1}, {9}}, // partition_tensors_shapes - std::vector({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); - model.InvokeFloat({4, 4}, // shape - {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values - 1.5, // default_value - std::vector>( - {std::vector({4}), - std::vector({0, 0, 0, 2, 2, 2, 2, 3, 3})})); - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4})); - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, - .4, .5, .6, .7, .8, .9, 1.5, 1.5})); -} - -TEST(RaggedTensorToTensorTest, RaggedTensorToTensorRowSplits) { - // indices = [2, 1, 0, 3] - // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] - RaggedTensorToTensorOpModel model(2, // output_shape_dims - {9}, // values_shape - {{5}}, // partition_tensors_shapes - std::vector({"ROW_SPLITS"})); - model.InvokeFloat( - {4, 4}, // shape - {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values - 1.5, // default_value - std::vector>({std::vector({0, 3, 3, 7, 9})})); - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4})); - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, - .4, .5, .6, .7, .8, .9, 1.5, 1.5})); -} - -TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParams) { - // params = [ - // [[]], - // [[.1, .2], [.3]], - // [], - // [[.4, .5], [.6, .7, .8]], - // [[.9]] - // ] - 
RaggedTensorToTensorOpModel model( - 3, // output_shape_dims - {9}, // values_shape - {{1}, {6}, {9}}, // partition_tensors_shapes - std::vector( - {"FIRST_DIM_SIZE", "VALUE_ROWIDS", "VALUE_ROWIDS"})); - model.InvokeFloat( - {5, 2, 3}, // shape - {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values - 1.5, // default_value - std::vector>( - {std::vector({5}), std::vector({0, 1, 1, 3, 3, 4}), - std::vector({1, 1, 2, 3, 3, 4, 4, 4, 5})})); - - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3})); - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2, - 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, - 1.5, 1.5, .4, .5, 1.5, .6, .7, .8, - .9, 1.5, 1.5, 1.5, 1.5, 1.5})); -} - -TEST(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParamsRowSplits) { - // params = [ - // [[]], - // [[.1, .2], [.3]], - // [], - // [[.4, .5], [.6, .7, .8]], - // [[.9]] - // ] - RaggedTensorToTensorOpModel model( - 3, // output_shape_dims - {9}, // values_shape - {{6}, {7}}, // partition_tensors_shapes - std::vector({"ROW_SPLITS", "ROW_SPLITS"})); - model.InvokeFloat( - {5, 2, 3}, // shape - {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values - 1.5, // default_value - std::vector>({std::vector({0, 1, 3, 3, 5, 6}), - std::vector({0, 0, 2, 3, 5, 8, 9})})); - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3})); - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2, - 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, - 1.5, 1.5, .4, .5, 1.5, .6, .7, .8, - .9, 1.5, 1.5, 1.5, 1.5, 1.5})); -} - -TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParamsRowSplits2) { - // params = [ - // [[0, 1, 2], []], - // [], - // [[3]] - // ] - - RaggedTensorToTensorOpModel model( - 3, // output_shape_dims - {4}, // values_shape - {{4}, {4}}, // partition_tensors_shapes - std::vector({"ROW_SPLITS", "ROW_SPLITS"}), TensorType_INT32); - model.InvokeInt( - {3, 2, 3}, // shape - {0, 1, 2, 3}, // 
values - 5, // default_value - std::vector>( - {std::vector({0, 2, 2, 3}), std::vector({0, 3, 3, 4})})); - - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 2, 3})); - - EXPECT_THAT(model.GetOutputInt(), - testing::ElementsAreArray( - {0, 1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5, 5, 5})); -} - -TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpanded) { - // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] - RaggedTensorToTensorOpModel model( - 2, // output_shape_dims - {9}, // values_shape - {{1}, {9}}, // partition_tensors_shapes - std::vector({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); - model.InvokeFloat({3, 5}, // shape - {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values - 1.5, // default_value - std::vector>( - {std::vector({4}), - std::vector({0, 0, 0, 2, 2, 2, 2, 3, 3})})); - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5})); - - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, // - 1.5, 1.5, 1.5, 1.5, 1.5, // - .4, .5, .6, .7, 1.5})); -} - -// Adds a dense dimension. 
-TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpandedDense) { - // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] - RaggedTensorToTensorOpModel model( - 3, // output_shape_dims - {9, 2}, // values_shape - {{1}, {9}}, // partition_tensors_shapes - std::vector({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); - - model.InvokeFloat({3, 5, 2}, // shape - {.1, 1.1, .2, 1.2, .3, 1.3, .4, 1.4, .5, 1.5, .6, 1.6, .7, - 1.7, .8, 1.8, .9, 1.9}, // values - 1.5, // default_value - std::vector>( - {std::vector({4}), - std::vector({0, 0, 0, 2, 2, 2, 2, 3, 3})})); - - EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5, 2})); - EXPECT_THAT(model.GetOutputFloat(), - testing::ElementsAreArray( - {.1, 1.1, .2, 1.2, .3, 1.3, 1.5, 1.5, 1.5, 1.5, // - 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, // - .4, 1.4, .5, 1.5, .6, 1.6, .7, 1.7, 1.5, 1.5})); -} - -TEST(RaggedTensorToTensorTest, StringType) { - RaggedTensorToTensorOpModel model( - 2, // output_shape_dims - {9}, // values_shape - {{1}, {9}}, // partition_tensors_shapes - std::vector({"FIRST_DIM_SIZE", "VALUE_ROWIDS"}), - TensorType_STRING, TensorType_INT32, /*allocate_and_delegate=*/false); - EXPECT_EQ(model.TryAllocateTensors(), kTfLiteError); -} - -} // namespace -} // namespace tflite diff --git a/tensorflow_text/core/kernels/regex_split.cc b/tensorflow_text/core/kernels/regex_split.cc deleted file mode 100644 index add317d1a..000000000 --- a/tensorflow_text/core/kernels/regex_split.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/regex_split.h" - -#include - -namespace tensorflow { -namespace text { -namespace { - -template -void RegexSplitImpl(absl::string_view input, const RE2& re2, - bool include_delimiter, const RE2& include_delim_regex, - std::vector* tokens, - std::vector* begin_offsets, - std::vector* end_offsets) { - absl::string_view leftover = input; - absl::string_view last_end = leftover; - - // Keep looking for split points until we have reached the end of the input. - absl::string_view extracted_delim_token; - while (RE2::FindAndConsume(&leftover, re2, &extracted_delim_token)) { - absl::string_view token(last_end.data(), - extracted_delim_token.data() - last_end.data()); - bool has_non_empty_token = token.length() > 0; - bool should_include_delim = - include_delimiter && include_delim_regex.FullMatch( - extracted_delim_token, include_delim_regex); - last_end = leftover; - - // Mark the end of the previous token, only if there was something. - if (has_non_empty_token) { - tokens->push_back(token); - // Mark the end of the last token - begin_offsets->push_back(token.data() - input.data()); - end_offsets->push_back(token.data() + token.length() - input.data()); - } - - if (should_include_delim) { - // If desired, include the deliminator as a token. - tokens->push_back(extracted_delim_token); - // Mark the end of the token at the end of the beginning of the delimiter. 
- begin_offsets->push_back(extracted_delim_token.data() - input.data()); - end_offsets->push_back(extracted_delim_token.data() + - extracted_delim_token.length() - input.data()); - } - } - - // Close the last token. - if (!leftover.empty()) { - tokens->push_back(leftover); - begin_offsets->push_back(leftover.data() - input.data()); - end_offsets->push_back(leftover.data() + leftover.length() - input.data()); - } -} - -} // namespace - -void RegexSplit(absl::string_view input, const RE2& re2, bool include_delimiter, - const RE2& include_delim_regex, - std::vector* tokens, - std::vector* begin_offsets, // NOLINT - std::vector* end_offsets) { // NOLINT - RegexSplitImpl(input, re2, include_delimiter, include_delim_regex, tokens, - begin_offsets, end_offsets); -} - -void RegexSplit(absl::string_view input, const RE2& re2, bool include_delimiter, - const RE2& include_delim_regex, - std::vector* tokens, - std::vector* begin_offsets, // NOLINT - std::vector* end_offsets) { // NOLINT - RegexSplitImpl(input, re2, include_delimiter, include_delim_regex, tokens, - begin_offsets, end_offsets); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/regex_split.h b/tensorflow_text/core/kernels/regex_split.h index 770efaa7e..12867d4d3 100644 --- a/tensorflow_text/core/kernels/regex_split.h +++ b/tensorflow_text/core/kernels/regex_split.h @@ -12,31 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ -#include -#include +#include "tensorflow/core/kernels/text/regex_split.h" -#include "absl/strings/string_view.h" -#include "re2/re2.h" - -namespace tensorflow { -namespace text { - -void RegexSplit(absl::string_view input, const RE2& re2, bool include_delimiter, - const RE2& include_delim_regex, - std::vector* tokens, - std::vector* begin_offsets, // NOLINT - std::vector* end_offsets); // NOLINT - -void RegexSplit(absl::string_view input, const RE2& re2, bool include_delimiter, - const RE2& include_delim_regex, - std::vector* tokens, - std::vector* begin_offsets, // NOLINT - std::vector* end_offsets); // NOLINT - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_REGEX_SPLIT_H_ diff --git a/tensorflow_text/core/kernels/regex_split_kernels.cc b/tensorflow_text/core/kernels/regex_split_kernels.cc deleted file mode 100644 index fd79a63d7..000000000 --- a/tensorflow_text/core/kernels/regex_split_kernels.cc +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow_text/core/kernels/regex_split.h" - -namespace tensorflow { -namespace text { - -class RegexSplitOp : public tensorflow::OpKernel { - public: - explicit RegexSplitOp(tensorflow::OpKernelConstruction* ctx) - : tensorflow::OpKernel(ctx) {} - - void Compute(tensorflow::OpKernelContext* ctx) override { - bool should_keep_delim; - std::shared_ptr delim_re; - std::shared_ptr keep_delim_re; - - // get regular expressions from input - const Tensor* delim_regex_pattern_tensor; - OP_REQUIRES_OK( - ctx, ctx->input("delim_regex_pattern", &delim_regex_pattern_tensor)); - OP_REQUIRES(ctx, - TensorShapeUtils::IsScalar(delim_regex_pattern_tensor->shape()), - errors::InvalidArgument( - "Pattern must be scalar, but received ", - delim_regex_pattern_tensor->shape().DebugString())); - const string delim_regex_pattern = - delim_regex_pattern_tensor->flat()(0); - delim_re = CachedDelimRE2(delim_regex_pattern); - OP_REQUIRES( - ctx, delim_re->ok(), - errors::InvalidArgument("Invalid pattern: ", delim_regex_pattern, - ", error: ", delim_re->error())); - - const Tensor* keep_delim_regex_pattern_tensor; - OP_REQUIRES_OK(ctx, ctx->input("keep_delim_regex_pattern", - &keep_delim_regex_pattern_tensor)); - OP_REQUIRES( - ctx, - TensorShapeUtils::IsScalar(keep_delim_regex_pattern_tensor->shape()), - errors::InvalidArgument( - "Pattern must be scalar, but received ", - keep_delim_regex_pattern_tensor->shape().DebugString())); - const string keep_delim_regex_pattern = - keep_delim_regex_pattern_tensor->flat()(0); - keep_delim_re = CachedKeepDelimRE2(keep_delim_regex_pattern); - OP_REQUIRES( - ctx, keep_delim_re->ok(), - errors::InvalidArgument("Invalid pattern: ", keep_delim_regex_pattern, - ", error: ", keep_delim_re->error())); - - should_keep_delim = 
keep_delim_re->pattern().empty() ? false : true; - - const Tensor* input_tensor; - OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); - - std::vector begin_offsets; - std::vector end_offsets; - std::vector tokens; - std::vector row_splits; - row_splits.push_back(0); - - for (size_t i = 0; i < input_flat.size(); ++i) { - RegexSplit(absl::string_view(input_flat(i).data()), *delim_re, - should_keep_delim, *keep_delim_re, &tokens, &begin_offsets, - &end_offsets); - row_splits.push_back(begin_offsets.size()); - } - - // Emit the flat Tensors needed to construct RaggedTensors for tokens, - // start, end offsets. - std::vector tokens_shape; - tokens_shape.push_back(tokens.size()); - - std::vector offsets_shape; - offsets_shape.push_back(begin_offsets.size()); - - std::vector row_splits_shape; - row_splits_shape.push_back(row_splits.size()); - - Tensor* output_tokens_tensor = nullptr; - OP_REQUIRES_OK(ctx, - ctx->allocate_output("tokens", TensorShape(tokens_shape), - &output_tokens_tensor)); - auto output_tokens = output_tokens_tensor->flat(); - - Tensor* output_begin_offsets_tensor = nullptr; - OP_REQUIRES_OK( - ctx, ctx->allocate_output("begin_offsets", TensorShape(offsets_shape), - &output_begin_offsets_tensor)); - auto output_begin_offsets = output_begin_offsets_tensor->flat(); - - Tensor* output_end_offsets_tensor = nullptr; - OP_REQUIRES_OK( - ctx, ctx->allocate_output("end_offsets", TensorShape(offsets_shape), - &output_end_offsets_tensor)); - auto output_end_offsets = output_end_offsets_tensor->flat(); - - Tensor* output_row_splits_tensor = nullptr; - OP_REQUIRES_OK( - ctx, ctx->allocate_output("row_splits", TensorShape(row_splits_shape), - &output_row_splits_tensor)); - auto output_row_splits = output_row_splits_tensor->flat(); - - // Copy outputs to Tensors. 
- for (size_t i = 0; i < tokens.size(); ++i) { - const auto& token = tokens[i]; - output_tokens(i) = tstring(token.data(), token.length()); - } - - for (size_t i = 0; i < begin_offsets.size(); ++i) { - output_begin_offsets(i) = begin_offsets[i]; - } - - for (size_t i = 0; i < end_offsets.size(); ++i) { - output_end_offsets(i) = end_offsets[i]; - } - - for (size_t i = 0; i < row_splits.size(); ++i) { - output_row_splits(i) = row_splits[i]; - } - } - - private: - std::shared_ptr CachedDelimRE2(const string& pattern) { - { - tf_shared_lock l(delim_mu_); - if (delim_re_ != nullptr && delim_re_->pattern() == pattern) { - return delim_re_; - } - } - // Construct the new RE2 object before acquiring the lock. - auto regex = std::make_shared(pattern); - { - mutex_lock l(delim_mu_); - // Swap instead of assigning so that we destruct the old - // RE2 object (when necessary) after releasing the lock. - delim_re_.swap(regex); - return delim_re_; - } - } - - std::shared_ptr CachedKeepDelimRE2(const string& pattern) { - { - tf_shared_lock l(keep_delim_mu_); - if (keep_delim_re_ != nullptr && keep_delim_re_->pattern() == pattern) { - return keep_delim_re_; - } - } - // Construct the new RE2 object before acquiring the lock. - auto regex = std::make_shared(pattern); - { - mutex_lock l(keep_delim_mu_); - // Swap instead of assigning so that we destruct the old - // RE2 object (when necessary) after releasing the lock. 
- keep_delim_re_.swap(regex); - return keep_delim_re_; - } - } - - mutex delim_mu_; - std::shared_ptr delim_re_ TF_GUARDED_BY(delim_mu_); - - mutex keep_delim_mu_; - std::shared_ptr keep_delim_re_ TF_GUARDED_BY(keep_delim_mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(RegexSplitOp); -}; - -REGISTER_KERNEL_BUILDER( - Name("RegexSplitWithOffsets").Device(tensorflow::DEVICE_CPU), RegexSplitOp); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/regex_split_test.cc b/tensorflow_text/core/kernels/regex_split_test.cc deleted file mode 100644 index 045623746..000000000 --- a/tensorflow_text/core/kernels/regex_split_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/regex_split.h" - -#include -#include -#include "absl/strings/string_view.h" -#include "re2/re2.h" -#include "tensorflow/core/platform/tstring.h" - -namespace tensorflow { -namespace text { -namespace { - -std::vector RunTest(const tstring& input, - const tstring& regex, - const tstring& delim_regex) { - RE2 re2((absl::string_view(regex))); - RE2 include_delim_re2((absl::string_view(delim_regex))); - - std::vector begin_offsets; - std::vector end_offsets; - std::vector tokens; - - RegexSplit(input, re2, true, include_delim_re2, &tokens, &begin_offsets, - &end_offsets); - return tokens; -} - -TEST(RegexSplitTest, JapaneseAndWhitespace) { - tstring regex = "(\\p{Hiragana}+|\\p{Katakana}+|\\s)"; - tstring delim_regex = "(\\p{Hiragana}+|\\p{Katakana}+)"; - tstring input = "He said フランスです"; - auto extracted_tokens = RunTest(input, regex, delim_regex); - EXPECT_THAT(extracted_tokens, testing::ElementsAreArray({ - "He", - "said", - "フランス", - "です", - })); -} - -TEST(RegexSplitTest, Japanese) { - tstring regex = "(\\p{Hiragana}+|\\p{Katakana}+)"; - tstring input = "He said フランスです"; - auto extracted_tokens = RunTest(input, regex, regex); - EXPECT_THAT(extracted_tokens, testing::ElementsAreArray({ - "He said ", - "フランス", - "です", - })); -} - -TEST(RegexSplitTest, ChineseHan) { - tstring regex = "(\\p{Han})"; - tstring input = "敵人變盟友背後盤算"; - auto extracted_tokens = RunTest(input, regex, regex); - EXPECT_THAT(extracted_tokens, - testing::ElementsAreArray( - {"敵", "人", "變", "盟", "友", "背", "後", "盤", "算"})); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/rouge_l_kernel.cc b/tensorflow_text/core/kernels/rouge_l_kernel.cc deleted file mode 100644 index 50a730f47..000000000 --- a/tensorflow_text/core/kernels/rouge_l_kernel.cc +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2026 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include "tensorflow/core/framework/lookup_interface.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace text { - -namespace { -} // namespace - - -// ROUGE-L implementation based on -// https://www.microsoft.com/en-us/research/publication/ -// rouge-a-package-for-automatic-evaluation-of-summaries/ -template -class RougeLOp : public OpKernel { - public: - using ConstFlatSplits = typename TTypes::ConstFlat; - using ConstFlatValues = typename TTypes::ConstFlat; - - explicit RougeLOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor& hyp_tensor = ctx->input(0); - const auto hyp_tensor_flat = hyp_tensor.flat(); - const Tensor& hyp_splits = ctx->input(1); - const auto hyp_splits_flat = hyp_splits.flat(); - - const Tensor& ref_tensor = ctx->input(2); - const auto ref_tensor_flat = ref_tensor.flat(); - const Tensor& ref_splits = ctx->input(3); - 
const auto ref_splits_flat = ref_splits.flat(); - - const Tensor& alpha_tensor = ctx->input(4); - const auto alpha_scalar = alpha_tensor.scalar(); - const float alpha = alpha_scalar(); - - // Alpha must be <=1. - OP_REQUIRES(ctx, alpha <= 1, - errors::InvalidArgument("alpha must be <1 but was=", alpha)); - - // Ref and Hyp must have the same number of rows. - OP_REQUIRES(ctx, ref_splits_flat.size() == hyp_splits_flat.size(), - errors::InvalidArgument( - "ref splits len=", ref_splits_flat.size(), - "must equal hyp splits len=", hyp_splits_flat.size())); - - // All inputs must be vectors. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(hyp_tensor.shape()), - errors::InvalidArgument("hypotheses values must be a vector")); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ref_tensor.shape()), - errors::InvalidArgument("references values must be a vector")); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(hyp_splits.shape()), - errors::InvalidArgument("hypotheses splits must be a vector")); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ref_splits.shape()), - errors::InvalidArgument("references splits must be a vector")); - // Ref and Hyp must have at least one split. - OP_REQUIRES(ctx, ref_splits_flat.size() > 0, - errors::InvalidArgument( - "ref splits len=0; must have at least 1 split")); - - // Output is a dense Tensor containing one row per input row. - TensorShape output_shape({ref_splits_flat.size() - 1}); - - // Allocate the F-Measure output tensor. - Tensor* f_measure_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output("f_measure", output_shape, - &f_measure_tensor)); - auto f_measures_flat = f_measure_tensor->flat(); - - // Allocate the P-Measure output tensor. - Tensor* p_measure_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output("p_measure", output_shape, - &p_measure_tensor)); - auto p_measures_flat = p_measure_tensor->flat(); - - // Allocate the R-Measure output tensor. 
- Tensor* r_measure_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output("r_measure", output_shape, - &r_measure_tensor)); - auto r_measures_flat = r_measure_tensor->flat(); - - // Iterate over the splits, skipping the first split as it is always zero. - for (int i = 1; i < hyp_splits_flat.size(); i++) { - // Length of hyp and ref. - SPLITS_TYPE lhyp = hyp_splits_flat(i) - hyp_splits_flat(i-1); - SPLITS_TYPE lref = ref_splits_flat(i) - ref_splits_flat(i-1); - // Length of longest common substring. - int32 llcs = LongestCommonSubsequenceLength(hyp_splits_flat(i-1), - hyp_splits_flat(i), - hyp_tensor_flat, - ref_splits_flat(i-1), - ref_splits_flat(i), - ref_tensor_flat); - auto measures = ComputeMeasures(lhyp, lref, llcs, alpha); - f_measures_flat(i - 1) = std::get<0>(measures); - p_measures_flat(i - 1) = std::get<1>(measures); - r_measures_flat(i - 1) = std::get<2>(measures); - } - } - - private: - // By using LCS, the ROUGE-L algorithm does not require consecutive matches - // but rather credits the order of N-grams. - int32 LongestCommonSubsequenceLength( - const SPLITS_TYPE hyp_i, - const SPLITS_TYPE hyp_j, - const ConstFlatValues& hyp, - const SPLITS_TYPE ref_i, - const SPLITS_TYPE ref_j, - const ConstFlatValues& ref) { - SPLITS_TYPE lhyp = hyp_j - hyp_i; - SPLITS_TYPE lref = ref_j - ref_i; - // Create a scratch matrix to keep track of the LCS seen so far using DP. - // http://www.algorithmist.com/index.php/Longest_Common_Subsequence - Tensor scratch(DT_INT32, {lhyp + 2, lref + 2}); - auto scratch2d = scratch.matrix(); - for (SPLITS_TYPE x = hyp_i; x <= hyp_j + 1; x++) { - for (SPLITS_TYPE y = ref_i; y <= ref_j + 1; y++) { - SPLITS_TYPE a = x - hyp_i; - SPLITS_TYPE b = y - ref_i; - if (a == 0 || b == 0) { - // If in first row or column, we write a zero to the table. 
- scratch2d(a, b) = 0; - } else if (x == hyp_j+1 || y == ref_j+1 || hyp(x-1) != ref(y-1)) { - // If in the last row or column, or if the tokens are not equal, - // carry the largest score seen in the cell above or to the left of - // the current cell. - scratch2d(a, b) = - std::max({scratch2d(a - 1, b), scratch2d(a, b - 1)}); - } else { - // If tokens are equal, we are part of a subsequence, so increment the - // diagonal score. - scratch2d(a, b) = scratch2d(a - 1, b - 1) + 1; - } - } - } - return scratch2d(lhyp, lref); - } - - std::tuple ComputeMeasures(const SPLITS_TYPE lhyp_int, - const SPLITS_TYPE lref_int, - const int32 llcs_int, - const float alpha) { - const float lhyp = static_cast(lhyp_int); - const float lref = static_cast(lref_int); - const float llcs = static_cast(llcs_int); - const float p_lcs = llcs / (lhyp + 1e-12); - const float r_lcs = llcs / (lref + 1e-12); - // Use the tensor2tensor formulation if the alpha value is <0, - // which does not make sense as a weighted average term. - const float f_lcs = alpha < 0 ? 
- ComputeTensor2TensorF(p_lcs, r_lcs) : - ComputeOfficialF(p_lcs, r_lcs, alpha); - return std::make_tuple(f_lcs, p_lcs, r_lcs); - } - - float ComputeTensor2TensorF(const float p_lcs, const float r_lcs) { - const float beta = p_lcs / (r_lcs + 1e-12); - const float numerator = (1 + (beta * beta)) * r_lcs * p_lcs; - const float denominator = r_lcs + ((beta * beta) * p_lcs); - if (denominator > 0) { - return numerator / denominator; - } - return 0; - } - - float ComputeOfficialF(const float p_lcs, const float r_lcs, - const float alpha) { - float denominator = (alpha * r_lcs + (1 - alpha) * p_lcs); - if (denominator > 0) { - return (p_lcs * r_lcs) / denominator; - } - return denominator; - } - - TF_DISALLOW_COPY_AND_ASSIGN(RougeLOp); -}; - -#define REGISTER(VALUES_TYPE) \ - REGISTER_KERNEL_BUILDER(Name("RougeL") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tsplits") \ - .TypeConstraint("Tvalues"), \ - RougeLOp); \ - REGISTER_KERNEL_BUILDER(Name("RougeL") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tsplits") \ - .TypeConstraint("Tvalues"), \ - RougeLOp); - -TF_CALL_int32(REGISTER); -TF_CALL_int64(REGISTER); -TF_CALL_string(REGISTER); -#undef REGISTER - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/rouge_l_kernel_test.cc b/tensorflow_text/core/kernels/rouge_l_kernel_test.cc deleted file mode 100644 index 3dc00a120..000000000 --- a/tensorflow_text/core/kernels/rouge_l_kernel_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/shape_inference_testutil.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -TEST(RougeLFMeasureOpTest, ShapeFn) { - ShapeInferenceTestOp op("RougeL"); - - INFER_OK(op, "[?];[3];[?];[3];[]", "[2];[2];[2]"); - INFER_OK(op, "[5];[3];[?];[3];[]", "[2];[2];[2]"); - INFER_OK(op, "[?];[3];[8];[3];[]", "[2];[2];[2]"); - INFER_OK(op, "[5];[3];[8];[3];[]", "[2];[2];[2]"); - INFER_OK(op, "[5];[3];[8];?;[]", "[2];[2];[2]"); - INFER_OK(op, "[5];?;[8];[3];[]", "[2];[2];[2]"); - INFER_OK(op, "[5];[?];[8];[?];[]", "[?];[?];[?]"); - INFER_OK(op, "?;?;?;?;?", "[?];[?];[?]"); - INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 2.", op, - "[5];[3];[8];[2];[]"); - INFER_ERROR("Shape must be rank 0 but is rank 1", op, - "[5];[3];[8];[3];[1]"); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/round_robin_trimmer.h b/tensorflow_text/core/kernels/round_robin_trimmer.h index 5273dfa9e..7a8b7014f 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer.h @@ -15,304 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_H_ -#include -#include -#include -#include -#include "tensorflow_text/core/kernels/trimmer.h" - - -namespace tensorflow { -namespace text { - -template -class RoundRobinTrimmer : Trimmer, BatchTrimmer { - using Values_ 
= Values; - using ValuesSpan_ = ValuesSpan; - using RowSplits_ = RowSplits; - using RowSplitsSpan_ = RowSplitsSpan; - - public: - RoundRobinTrimmer(int max_sequence_length) - : max_sequence_length_(std::max(max_sequence_length, 0)) {} - virtual ~RoundRobinTrimmer() = default; - - // Generates masks for a single batch of values. - std::vector GenerateMasks( - const std::vector& values) const; - - // Generates masks for a batch of values row splits. - // - // Args: - // row_splits: Row splits of the values in the shape [batch, (num values)] - // - // Returns: - // The returned value is a flattened list of mask values which can be split - // into batches using the same input row splits. - std::vector GenerateMasksBatch( - const std::vector& row_splits) const; - std::vector GenerateMasksBatch( - const std::vector& row_splits) const; - - // Trims a single batch of values. - void Trim(std::vector* values) const; - - // Trims a batch of values given their flattened values and row splits. - // - // Args: - // flat_values: Flattened values in shape [batch, (num values)] - // row_splits: Row splits of the values in the shape [batch, (num values)] - // - // Returns: - // The returned values are the flattened trimmed values and new row splits. - std::pair, std::vector> TrimBatch( - const std::vector& flat_values, - const std::vector& row_splits) const; - std::pair, std::vector> TrimBatch( - const std::vector& flat_values, - const std::vector& row_splits) const; - - protected: - // Used for holding data about value sizes and how much of it is used. - struct Row { - Row() : idx(0), size(0), used(0) {} - Row(int idx, int size, int used) : idx(idx), size(size), used(used) {} - int idx; // Index into the list of values - Tsplits size; // Size of the row values - int used; // How much of the values is used - }; - - // Internal execution to share code for Span & Vector row_splits. 
- template - std::vector GenerateMasksInternal(Iterator begin, Iterator end) const; - - // Internal execution to share code for Span & Vector row_splits. - template - std::pair, std::vector> TrimInternal( - ValuesIterator flat_values_begin, - ValuesIterator flat_values_end, - RowSplitsIterator row_splits_begin, - RowSplitsIterator row_splits_end) const; - - // Main process of the timmer. Process row splits a batch at a time. Once each - // it is known how much each row in a batch is used, the callback is called - // with the row information. - // Algorithm to fill values: - // 1. Fill values that will max starting from smallest to largest. - // 2. Partially fill the rest up the same amount up to the sequence length. - // 3. Add the remainder to the available rows in order. - template - void ProcessBatch(Iterator values_begin, Iterator values_end, - std::function*)> callback) const; - void ProcessBatch(std::vector* value_row_sizes, - std::function*)> callback) const; - - template - void ProcessSplitsByBatch(Iterator begin, Iterator end, - std::function*)> callback) const; - - const int max_sequence_length_; -}; - -/******************************* Implementation *******************************/ - -template -std::vector RoundRobinTrimmer::GenerateMasks( - const std::vector& values) const { - std::vector masks(values.size()); - ProcessBatch(values.begin(), values.end(), - [&masks](std::vector* value_row_sizes) { - for (int i = 0; i < masks.size(); ++i) { - Mask& mask = masks[i]; - const Row& values_row = (*value_row_sizes)[i]; - mask.reserve(values_row.size); - mask.insert(mask.end(), values_row.used, true); - mask.insert(mask.end(), values_row.size - values_row.used, false); - } - }); - return masks; -} - -template -std::vector RoundRobinTrimmer::GenerateMasksBatch( - const std::vector& row_splits) const { - return GenerateMasksInternal(row_splits.begin(), row_splits.end()); -} - -template -std::vector RoundRobinTrimmer::GenerateMasksBatch( - const std::vector& 
row_splits) const { - return GenerateMasksInternal(row_splits.begin(), row_splits.end()); -} - -template -template -std::vector RoundRobinTrimmer::GenerateMasksInternal( - const Iterator begin, const Iterator end) const { - // First reserve necessary space for the masks - std::vector masks(end - begin); - auto m = masks.begin(); - for (auto it = begin; it != end; ++it, ++m) { - m->reserve(it->back()); - } - // Process all batches, updating the masks a batch at a time. - ProcessSplitsByBatch(begin, end, [&masks](std::vector* rows) { - for (int s = 0; s < masks.size(); ++s) { - const Row& row = (*rows)[s]; - masks[s].reserve(row.size); - masks[s].insert(masks[s].end(), row.used, true); - masks[s].insert(masks[s].end(), row.size - row.used, false); - } - }); - return masks; -} - -template -void RoundRobinTrimmer::Trim(std::vector* values) const { - ProcessBatch(values->begin(), values->end(), - [values] (std::vector* value_row_sizes) { - for (int s = 0; s < values->size(); ++s) { - (*values)[s].resize((*value_row_sizes)[s].used); - } - }); -} - -template -std::pair>, std::vector>> -RoundRobinTrimmer::TrimBatch( - const std::vector& flat_values, - const std::vector& row_splits) const { - return TrimInternal( - flat_values.begin(), flat_values.end(), - row_splits.begin(), row_splits.end()); -} - -template -std::pair>, std::vector>> -RoundRobinTrimmer::TrimBatch( - const std::vector& flat_values, - const std::vector& row_splits) const { - return TrimInternal( - flat_values.begin(), flat_values.end(), - row_splits.begin(), row_splits.end()); -} - -template -template -std::pair>, std::vector>> -RoundRobinTrimmer::TrimInternal( - ValuesIterator flat_values_begin, - ValuesIterator flat_values_end, - RowSplitsIterator splits_begin, - RowSplitsIterator splits_end) const { - std::pair, std::vector> trimmed( - {std::vector(flat_values_end - flat_values_begin), - std::vector(splits_end - splits_begin)}); - // All row splits start at index 0 - for (int i = 0; i < 
trimmed.second.size(); ++i) { - trimmed.second[i].push_back({0}); - } - ProcessSplitsByBatch(splits_begin, splits_end, - [&trimmed, flat_values_begin, splits_begin](std::vector* values_row) - { - auto values_it = flat_values_begin; - auto splits_it = splits_begin; - for (int s = 0; s < values_row->size(); ++s, ++values_it, ++splits_it) { - Values_* vals = &trimmed.first[s]; - RowSplits_* splits = &trimmed.second[s]; - auto start = values_it->begin() + (*splits_it)[splits->size()-1]; - vals->insert(vals->end(), start, start + (*values_row)[s].used); - splits->insert(splits->end(), splits->back() + (*values_row)[s].used); - } - }); - return trimmed; -} - -template -template -void RoundRobinTrimmer::ProcessBatch( - Iterator values_begin, Iterator values_end, - std::function*)> callback) const { - int num_values = values_end - values_begin; - // Get size of each segment - std::vector value_row_sizes(num_values); - int i = 0; - for (auto it = values_begin; it != values_end; ++it, ++i) { - value_row_sizes[i].idx = i; - value_row_sizes[i].size = it->size(); - } - // Process the values - ProcessBatch(&value_row_sizes, callback); -} - -template -void RoundRobinTrimmer::ProcessBatch( - std::vector* value_row_sizes, - std::function*)> callback) const { - int num_values = value_row_sizes->size(); - int sequence_left = max_sequence_length_; - - // Fill all values to the max (smallest first to largest) that we can - // without crossing the max_sequence_length - std::sort(value_row_sizes->begin(), value_row_sizes->end(), - [] (Row a, Row b) { return a.size < b.size; }); - int filled_value_rows = 0; - for (int i = 0; i < num_values; ++i) { - // Break if we will not be able to fill up the smallest unfilled value row - if ((*value_row_sizes)[i].size * (num_values - filled_value_rows) - > sequence_left) { - break; - } - (*value_row_sizes)[i].used = (*value_row_sizes)[i].size; - sequence_left -= (*value_row_sizes)[i].used; - ++filled_value_rows; - } - - // Fill the remaining value 
rows evenly - if (filled_value_rows < num_values) { - int count = sequence_left / (num_values - filled_value_rows); - for (int i = filled_value_rows; i < num_values; ++i) { - (*value_row_sizes)[i].used = count; - sequence_left -= count; - } - } - - // Finally add the remainder - index order - std::sort(value_row_sizes->begin(), value_row_sizes->end(), - [] (Row a, Row b) { return a.idx < b.idx; }); - for (int i = 0; i < num_values && sequence_left > 0; ++i) { - if ((*value_row_sizes)[i].used < (*value_row_sizes)[i].size) { - ++((*value_row_sizes)[i].used); - --sequence_left; - } - } - - // Usage of rows computed. Execute callback to process. - callback(value_row_sizes); -} - -template -template -void RoundRobinTrimmer::ProcessSplitsByBatch( - Iterator begin, Iterator end, - std::function*)> callback) const { - int num_in_batch = begin->size() - 1; - int num_values = end - begin; - // Process one batch at a time. - std::vector value_row_sizes(num_values); - for (int batch_idx = 0; batch_idx < num_in_batch; ++batch_idx) { - // First, get size of each row. - int idx = 0; - for (auto i = begin; i < end; ++i, ++idx) { - value_row_sizes[idx].idx = idx; - value_row_sizes[idx].size = (*i)[batch_idx + 1] - (*i)[batch_idx]; - } - // Perform the main processing of the batch - ProcessBatch(&value_row_sizes, callback); - } -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/round_robin_trimmer.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_H_ diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_kernel.cc b/tensorflow_text/core/kernels/round_robin_trimmer_kernel.cc deleted file mode 100644 index 4635d7b45..000000000 --- a/tensorflow_text/core/kernels/round_robin_trimmer_kernel.cc +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/round_robin_trimmer_kernel.h" - -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" - -namespace tensorflow { -namespace text { - -using RoundRobinGenerateMasksOpKernelInstance = - RoundRobinGenerateMasksOpKernel; - -#define REGISTER_ROUND_ROBIN_GENERATE_MASKS_SPLITS(vals_type, splits_type) \ - REGISTER_KERNEL_BUILDER( \ - Name(RoundRobinGenerateMasksOpKernelInstance::OpName()) \ - .Device(tensorflow::DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tsplits"), \ - RoundRobinGenerateMasksOpKernel); - -#define REGISTER_ROUND_ROBIN_GENERATE_MASKS(vals_type) \ - REGISTER_ROUND_ROBIN_GENERATE_MASKS_SPLITS(vals_type, int32_t) \ - REGISTER_ROUND_ROBIN_GENERATE_MASKS_SPLITS(vals_type, int64_t) - -TF_CALL_tstring(REGISTER_ROUND_ROBIN_GENERATE_MASKS) -TF_CALL_bool(REGISTER_ROUND_ROBIN_GENERATE_MASKS) -TF_CALL_float(REGISTER_ROUND_ROBIN_GENERATE_MASKS) -TF_CALL_double(REGISTER_ROUND_ROBIN_GENERATE_MASKS) -TF_CALL_INTEGRAL_TYPES(REGISTER_ROUND_ROBIN_GENERATE_MASKS) - -#undef REGISTER_ROUND_ROBIN_GENERATE_MASKS -#undef REGISTER_ROUND_ROBIN_GENERATE_MASKS_SPLITS - - using RoundRobinTrimOpKernelInstance = - RoundRobinTrimOpKernel; - -#define REGISTER_ROUND_ROBIN_TRIM_SPLITS(vals_type, splits_type) \ - REGISTER_KERNEL_BUILDER(Name(RoundRobinTrimOpKernelInstance::OpName()) \ - .Device(tensorflow::DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tsplits"), \ - RoundRobinTrimOpKernel); - -#define REGISTER_ROUND_ROBIN_TRIM(vals_type) \ - 
REGISTER_ROUND_ROBIN_TRIM_SPLITS(vals_type, int32_t) \ - REGISTER_ROUND_ROBIN_TRIM_SPLITS(vals_type, int64_t) - -TF_CALL_tstring(REGISTER_ROUND_ROBIN_TRIM) -TF_CALL_bool(REGISTER_ROUND_ROBIN_TRIM) -TF_CALL_float(REGISTER_ROUND_ROBIN_TRIM) -TF_CALL_double(REGISTER_ROUND_ROBIN_TRIM) -TF_CALL_INTEGRAL_TYPES(REGISTER_ROUND_ROBIN_TRIM) - -#undef REGISTER_ROUND_ROBIN_TRIM -#undef REGISTER_ROUND_ROBIN_TRIM_SPLITS - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_kernel.h b/tensorflow_text/core/kernels/round_robin_trimmer_kernel.h index 69edec748..0383529ab 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer_kernel.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer_kernel.h @@ -15,29 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h" - -namespace tensorflow { -namespace text { - -template -class RoundRobinGenerateMasksOpKernel - : public tflite::shim::TfOpKernel { - public: - using tflite::shim::TfOpKernel::TfOpKernel; -}; - -template -class RoundRobinTrimOpKernel - : public tflite::shim::TfOpKernel { - public: - using tflite::shim::TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow - +#include "tensorflow/core/kernels/text/round_robin_trimmer_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h b/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h index 51f17da43..b56b9c0e4 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h @@ -15,310 +15,6 @@ #ifndef 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_TEMPLATE_H_ -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/round_robin_trimmer.h" - -namespace tensorflow { -namespace text { - -template -class RoundRobinTrimOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kMaxSeqLength = 0, - kInputValues, - kInputRowSplits - }; - enum Outputs { - kOutputValues = 0, - kOutputRowSplits - }; - int64_t number_of_segments_; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - RoundRobinTrimOp() = default; - static constexpr char kOpName[] = "TFText>RoundRobinTrim"; - static constexpr char kDoc[] = R"doc( - Trims a tensor. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs(); - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { - // Attr - SH_RETURN_IF_ERROR(context->GetAttr("N", &number_of_segments_)); - return absl::OkStatus(); - } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template -std::vector RoundRobinTrimOp::Attrs() { - return {"N: int >= 1", "T: type", "Tsplits: {int32, int64}"}; -} - -template -std::vector RoundRobinTrimOp::Inputs() { - return {"max_sequence_length: int32", "input_values: N * T", - "input_row_splits: N * Tsplits"}; -} - -template -std::vector RoundRobinTrimOp::Outputs() { - return {"values: N * T", "row_splits: N * Tsplits"}; -} - -template -absl::Status RoundRobinTrimOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - int64_t num_segments; - SH_RETURN_IF_ERROR(c->GetAttr("N", &num_segments)); - - SH_ASSIGN_OR_RETURN(const Shape& max_seq_shape, - c->GetInputShape(kMaxSeqLength)); - if (!max_seq_shape.Compatible(Shape({}))) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be a scalar: ", max_seq_shape.ToString())); - } - - for (int i = 0; i < num_segments; ++i) { - SH_ASSIGN_OR_RETURN( - const Shape& values_shape, - c->GetInputShape( - (kInputValues - 1) * num_segments + i + 1)); - if (!values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", values_shape.ToString())); - } - 
- SH_ASSIGN_OR_RETURN( - const Shape& row_splits_shape, - c->GetInputShape( - (kInputRowSplits - 1) * num_segments + i + 1)); - if (!row_splits_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", row_splits_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape( - kOutputRowSplits * num_segments + i, row_splits_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape( - kOutputValues * num_segments + i, rank_1_shape)); - } - - return absl::OkStatus(); -} - -template -absl::Status RoundRobinTrimOp::Invoke(InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto msl, context->GetInput(kMaxSeqLength)); - const int max_sequence_length = msl->template AsScalar(); - - std::vector> list_of_values(number_of_segments_); - std::vector> list_of_splits(number_of_segments_); - for (int i = 0; i < number_of_segments_; ++i) { - SH_ASSIGN_OR_RETURN(const auto fv, context->GetInput(kInputValues + i)); - list_of_values[i] = fv->template Data(); - - int row_split_idx = kInputRowSplits + number_of_segments_ - 1 + i; - SH_ASSIGN_OR_RETURN(const auto rs, context->GetInput(row_split_idx)); - list_of_splits[i] = rs->template Data(); - } - - // Compute - RoundRobinTrimmer trimmer(max_sequence_length); - auto [trimmed_vals, trimmed_splits] = trimmer.TrimBatch( - list_of_values, list_of_splits); - - for (int i = 0; i < number_of_segments_; ++i) { - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR(this->template FillOutputTensor( - trimmed_vals[i], (kOutputValues * number_of_segments_) + i, context)); - SH_RETURN_IF_ERROR( - this->template FillOutputTensor(trimmed_splits[i], - (kOutputRowSplits * number_of_segments_) + i, context)); - } - - return absl::OkStatus(); -} - -template -class RoundRobinGenerateMasksOp - : public tflite::shim::OpKernelShim { - private: - enum Inputs { - kMaxSeqLength = 0, - kInputValues, - kInputRowSplits - }; - enum Outputs { - kOutputMasks = 0 - }; - int64_t number_of_segments_; - - using typename tflite::shim::OpKernelShim::InitContext; - using typename tflite::shim::OpKernelShim::InvokeContext; - using typename tflite::shim::OpKernelShim::ShapeInferenceContext; - - public: - RoundRobinGenerateMasksOp() = default; - static constexpr char kOpName[] = "TFText>RoundRobinGenerateMasks"; - static constexpr char kDoc[] = R"doc( - Generates a mask for trimming a tensor. - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Attrs(); - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { - // Attr - SH_RETURN_IF_ERROR(context->GetAttr("N", &number_of_segments_)); - return absl::OkStatus(); - } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template -std::vector RoundRobinGenerateMasksOp::Attrs() { - return {"N: int >= 1", "T: type", "Tsplits: {int32, int64}"}; -} - -template -std::vector RoundRobinGenerateMasksOp::Inputs() { - // TODO(broken): use templated value - return {"max_sequence_length: int32", 
"input_values: N * T", - "input_row_splits: N * Tsplits"}; -} - -template -std::vector RoundRobinGenerateMasksOp::Outputs() { - return {"masks: N * bool"}; -} - -template -absl::Status RoundRobinGenerateMasksOp::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - int64_t num_segments; - SH_RETURN_IF_ERROR(c->GetAttr("N", &num_segments)); - - SH_ASSIGN_OR_RETURN(const Shape& max_seq_shape, - c->GetInputShape(kMaxSeqLength)); - if (!max_seq_shape.Compatible(Shape({}))) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be a scalar: ", max_seq_shape.ToString())); - } - - for (int i = 0; i < num_segments; ++i) { - SH_ASSIGN_OR_RETURN( - const Shape& values_shape, - c->GetInputShape( - (kInputValues - 1) * num_segments + i + 1)); - if (!values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", values_shape.ToString())); - } - - SH_ASSIGN_OR_RETURN( - const Shape& row_splits_shape, - c->GetInputShape( - (kInputRowSplits - 1) * num_segments + i + 1)); - if (!row_splits_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", row_splits_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape( - kOutputMasks * num_segments + i, values_shape)); - } - - return absl::OkStatus(); -} - -template -absl::Status RoundRobinGenerateMasksOp::Invoke( - InvokeContext* context) { - // Inputs - SH_ASSIGN_OR_RETURN(const auto msl, context->GetInput(kMaxSeqLength)); - const int max_sequence_length = msl->template AsScalar(); - - std::vector> list_of_splits(number_of_segments_); - for (int i = 0; i < number_of_segments_; ++i) { - int row_split_idx = kInputRowSplits + number_of_segments_ - 1 + i; - SH_ASSIGN_OR_RETURN(const auto rs, context->GetInput(row_split_idx)); - list_of_splits[i] = rs->template Data(); - } - - // Compute - RoundRobinTrimmer 
trimmer(max_sequence_length); - std::vector> masks = - trimmer.GenerateMasksBatch(list_of_splits); - - for (int i = 0; i < number_of_segments_; ++i) { - // Allocate output & fill output tensors. - SH_RETURN_IF_ERROR(this->template FillOutputTensor(masks[i], - (kOutputMasks * number_of_segments_) + i, context)); - } - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/round_robin_trimmer_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_test.cc b/tensorflow_text/core/kernels/round_robin_trimmer_test.cc deleted file mode 100644 index 50c21e32d..000000000 --- a/tensorflow_text/core/kernels/round_robin_trimmer_test.cc +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/round_robin_trimmer.h" - -#include -#include -#include - -#include -#include - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::ElementsAreArray; - -struct TestSpec { - int max_sequence_length; - std::vector vals_a_row_1; - std::vector vals_a_row_2; - std::vector vals_b_row_1; - std::vector vals_b_row_2; - std::vector mask_a_row_1; - std::vector mask_a_row_2; - std::vector mask_b_row_1; - std::vector mask_b_row_2; -}; - -class RoundRobinTrimmerTest : public testing::TestWithParam { - protected: - using Segment = std::vector; - using SegmentBatch = std::vector; - using Splits = std::vector; - using Masks = std::vector; - using MasksBatch = std::vector; - - std::vector GetRaggedInput() { - SegmentBatch a = {input_a_row_1, input_a_row_2}; - SegmentBatch b = {input_b_row_1, input_b_row_2}; - - return {a, b}; - } - - std::vector GetFirstBatch() { - return {input_a_row_1, input_b_row_1}; - } - - std::vector GetSecondBatch() { - return {input_a_row_2, input_b_row_2}; - } - - std::pair, std::vector> GetFlatInput() { - Segment a_vals(input_a_row_1.begin(), input_a_row_1.end()); - a_vals.insert(a_vals.end(), input_a_row_2.begin(), input_a_row_2.end()); - Segment b_vals(input_b_row_1.begin(), input_b_row_1.end()); - b_vals.insert(b_vals.end(), input_b_row_2.begin(), input_b_row_2.end()); - - Splits a_splits = {0}; - a_splits.push_back(input_a_row_1.size()); - a_splits.push_back(a_splits.back() + input_a_row_2.size()); - Splits b_splits = {0}; - b_splits.push_back(input_b_row_1.size()); - b_splits.push_back(b_splits.back() + input_b_row_2.size()); - - std::vector vals = {a_vals, b_vals}; - std::vector splits = {a_splits, b_splits}; - return std::make_pair(vals, splits); - } - - template - std::vector Concat(std::vector a, std::vector b) { - std::vector result(a.begin(), a.end()); - result.insert(result.end(), b.begin(), b.end()); - return result; - } - - private: - const Segment input_a_row_1 = {1, 2, 3, 4, 
5}; - const Segment input_a_row_2 = {6, 7}; - const Segment input_b_row_1 = {10, 20, 30, 40, 50}; - const Segment input_b_row_2 = {60, 70}; -}; - -static const std::vector& params = { - { - .max_sequence_length = 10, - .vals_a_row_1 = {1, 2, 3, 4, 5}, - .vals_a_row_2 = {6, 7}, - .vals_b_row_1 = {10, 20, 30, 40, 50}, - .vals_b_row_2 = {60, 70}, - .mask_a_row_1 = {true, true, true, true, true}, - .mask_a_row_2 = {true, true}, - .mask_b_row_1 = {true, true, true, true, true}, - .mask_b_row_2 = {true, true}, - }, - { - .max_sequence_length = 6, - .vals_a_row_1 = {1, 2, 3}, - .vals_a_row_2 = {6, 7}, - .vals_b_row_1 = {10, 20, 30}, - .vals_b_row_2 = {60, 70}, - .mask_a_row_1 = {true, true, true, false, false}, - .mask_a_row_2 = {true, true}, - .mask_b_row_1 = {true, true, true, false, false}, - .mask_b_row_2 = {true, true}, - }, - { - .max_sequence_length = 3, - .vals_a_row_1 = {1, 2}, - .vals_a_row_2 = {6, 7}, - .vals_b_row_1 = {10}, - .vals_b_row_2 = {60}, - .mask_a_row_1 = {true, true, false, false, false}, - .mask_a_row_2 = {true, true}, - .mask_b_row_1 = {true, false, false, false, false}, - .mask_b_row_2 = {true, false}, - }, - { - .max_sequence_length = 0, - .vals_a_row_1 = {}, - .vals_a_row_2 = {}, - .vals_b_row_1 = {}, - .vals_b_row_2 = {}, - .mask_a_row_1 = {false, false, false, false, false}, - .mask_a_row_2 = {false, false}, - .mask_b_row_1 = {false, false, false, false, false}, - .mask_b_row_2 = {false, false}, - } -}; - -TEST_P(RoundRobinTrimmerTest, GenerateMasks) { - TestSpec p = GetParam(); - RoundRobinTrimmer t(p.max_sequence_length); - std::vector masks1 = t.GenerateMasks(GetFirstBatch()); - EXPECT_THAT(masks1[0], ElementsAreArray(p.mask_a_row_1)); - EXPECT_THAT(masks1[1], ElementsAreArray(p.mask_b_row_1)); - std::vector masks2 = t.GenerateMasks(GetSecondBatch()); - EXPECT_THAT(masks2[0], ElementsAreArray(p.mask_a_row_2)); - EXPECT_THAT(masks2[1], ElementsAreArray(p.mask_b_row_2)); -} - -TEST_P(RoundRobinTrimmerTest, GenerateMasks_flat) { - TestSpec p 
= GetParam(); - RoundRobinTrimmer t(p.max_sequence_length); - std::vector masks = t.GenerateMasksBatch(GetFlatInput().second); - EXPECT_THAT(masks[0], - ElementsAreArray(Concat(p.mask_a_row_1, p.mask_a_row_2))); - EXPECT_THAT(masks[1], - ElementsAreArray(Concat(p.mask_b_row_1, p.mask_b_row_2))); -} - -TEST_P(RoundRobinTrimmerTest, Trim) { - TestSpec p = GetParam(); - RoundRobinTrimmer t(p.max_sequence_length); - std::vector vals1 = GetFirstBatch(); - t.Trim(&vals1); - EXPECT_THAT(vals1[0], ElementsAreArray(p.vals_a_row_1)); - EXPECT_THAT(vals1[1], ElementsAreArray(p.vals_b_row_1)); - std::vector vals2 = GetSecondBatch(); - t.Trim(&vals2); - EXPECT_THAT(vals2[0], ElementsAreArray(p.vals_a_row_2)); - EXPECT_THAT(vals2[1], ElementsAreArray(p.vals_b_row_2)); -} - -TEST_P(RoundRobinTrimmerTest, Trim_flat) { - TestSpec p = GetParam(); - RoundRobinTrimmer t(p.max_sequence_length); - auto [input_vals, input_splits] = GetFlatInput(); - auto [vals, splits] = t.TrimBatch(input_vals, input_splits); - EXPECT_THAT(vals[0], - ElementsAreArray(Concat(p.vals_a_row_1, p.vals_a_row_2))); - EXPECT_THAT(vals[1], - ElementsAreArray(Concat(p.vals_b_row_1, p.vals_b_row_2))); - std::vector result_splits = { 0 }; - result_splits.push_back(p.vals_a_row_1.size()); - result_splits.push_back(p.vals_a_row_1.size() + p.vals_a_row_2.size()); - EXPECT_THAT(splits[0], ElementsAreArray(result_splits)); - result_splits = { 0 }; - result_splits.push_back(p.vals_b_row_1.size()); - result_splits.push_back(p.vals_b_row_1.size() + p.vals_b_row_2.size()); - EXPECT_THAT(splits[1], ElementsAreArray(result_splits)); -} - -TEST_P(RoundRobinTrimmerTest, Trim_int64) { - TestSpec p = GetParam(); - RoundRobinTrimmer t(p.max_sequence_length); - auto [input_vals, input_splits] = GetFlatInput(); - std::vector> input_splits_64(input_splits.size()); - for (int i = 0; i < input_splits.size(); ++i) - input_splits_64[i].insert(input_splits_64[i].end(), - input_splits[i].begin(), input_splits[i].end()); - std::vector> 
input_vals_64(input_vals.size()); - for (int i = 0; i < input_vals.size(); ++i) - input_vals_64[i].insert(input_vals_64[i].end(), - input_vals[i].begin(), input_vals[i].end()); - auto [vals, splits] = t.TrimBatch(input_vals_64, input_splits_64); - EXPECT_THAT(vals[0], - ElementsAreArray(Concat(p.vals_a_row_1, p.vals_a_row_2))); - EXPECT_THAT(vals[1], - ElementsAreArray(Concat(p.vals_b_row_1, p.vals_b_row_2))); - std::vector result_splits = { 0 }; - result_splits.push_back(p.vals_a_row_1.size()); - result_splits.push_back(p.vals_a_row_1.size() + p.vals_a_row_2.size()); - EXPECT_THAT(splits[0], ElementsAreArray(result_splits)); - result_splits = { 0 }; - result_splits.push_back(p.vals_b_row_1.size()); - result_splits.push_back(p.vals_b_row_1.size() + p.vals_b_row_2.size()); - EXPECT_THAT(splits[1], ElementsAreArray(result_splits)); -} - -INSTANTIATE_TEST_SUITE_P(RoundRobinTrimmerTestSuite, - RoundRobinTrimmerTest, - testing::ValuesIn(params)); - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_tflite.cc b/tensorflow_text/core/kernels/round_robin_trimmer_tflite.cc deleted file mode 100644 index a724beba6..000000000 --- a/tensorflow_text/core/kernels/round_robin_trimmer_tflite.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/round_robin_trimmer_tflite.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow/lite/kernels/shim/tflite_op_wrapper.h" -#include "tensorflow/lite/mutable_op_resolver.h" -#include "tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -namespace { -const char splits_type[]("Tsplits"), vals_type[]("T"); -} // namespace - -using ::tflite::shim::op_wrapper::Attr; -using ::tflite::shim::op_wrapper::AttrName; -using ::tflite::shim::op_wrapper::OpWrapper; - -template -using GenerateMasksOp = - OpWrapper, ::tensorflow::tstring, float, double, - int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, - int64_t, uint64_t, bool>, - Attr, int32_t, int64_t>>; - -extern "C" void AddRoundRobinGenerateMasks( - tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel::Add(resolver); -} - -template -using TrimOp = - OpWrapper, ::tensorflow::tstring, float, double, - int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, - int64_t, uint64_t, bool>, - Attr, int32_t, int64_t>>; - -extern "C" void AddRoundRobinTrim(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_tflite.h b/tensorflow_text/core/kernels/round_robin_trimmer_tflite.h index 46ffe63b9..6c4459dae 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer_tflite.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer_tflite.h @@ -15,21 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_TFLITE_H_ -#include 
"tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddRoundRobinGenerateMasks(tflite::MutableOpResolver* resolver); - -extern "C" void AddRoundRobinTrim(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/round_robin_trimmer_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/sentence_breaking_kernels.cc b/tensorflow_text/core/kernels/sentence_breaking_kernels.cc deleted file mode 100644 index 0f4c34c82..000000000 --- a/tensorflow_text/core/kernels/sentence_breaking_kernels.cc +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/ucnv_err.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow_text/core/kernels/sentence_breaking_utils.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter.h" - -using ::tensorflow::tstring; -using ::tensorflow::errors::InvalidArgument; - -namespace tensorflow { -namespace text { - -// TODO(thuang513): This is copied from unicode_ops.cc, move this to a separate -// util lib in tensorflow and reuse it here instead. -namespace { -// Lifecycle wrapper for UConverter making it easier to use with thread_local. -// TODO(gregbillock): Consider whether to use the higher-level convert API and -// create a specialized fast code path for UTF8. -class WrappedConverter { - public: - WrappedConverter() {} - - ~WrappedConverter() { - if (converter_) { - ucnv_close(converter_); - } - } - - void init(const string& name) { - if (converter_ && name == name_) { - // Note: this reset is not typically needed, but if not done, then in some - // cases the cached converter will maintain state of input endianness - // which isn't valid from input to input in every batched case. 
- ucnv_reset(converter_); - return; - } - - if (converter_) { - ucnv_close(converter_); - converter_ = nullptr; - name_ = ""; - } - - UErrorCode status = U_ZERO_ERROR; - converter_ = ucnv_open(name.c_str(), &status); - if (U_FAILURE(status)) { - if (converter_) { - ucnv_close(converter_); - converter_ = nullptr; - } - } else { - name_ = name; - } - } - - UConverter* converter_ = nullptr; - string name_; -}; - -struct ErrorOptions { - UChar32 subst = 0xFFFD; - bool elide_replacement = false; - bool replace_control_chars = false; - bool error_on_malformatting = false; -}; - -absl::Status GetErrorOptions(OpKernelConstruction* context, ErrorOptions* out) { - *out = ErrorOptions(); - - string error_policy; - TF_RETURN_IF_ERROR(context->GetAttr("errors", &error_policy)); - - if (error_policy == "replace") { - out->elide_replacement = false; - } else if (error_policy == "ignore") { - out->elide_replacement = true; - } else if (error_policy == "strict") { - out->error_on_malformatting = true; - } else { - return InvalidArgument( - "errors policy must be one of 'strict', 'replace', or 'ignore'"); - } - - int32 replacement_char; - TF_RETURN_IF_ERROR(context->GetAttr("replacement_char", &replacement_char)); - - if (replacement_char >= UCHAR_MIN_VALUE && - replacement_char <= UCHAR_MAX_VALUE) { - out->subst = replacement_char; - } else { - return InvalidArgument("replacement_char out of unicode codepoint range"); - } - - if (context->HasAttr("replace_control_characters")) { - TF_RETURN_IF_ERROR(context->GetAttr("replace_control_characters", - &(out->replace_control_chars))); - } - - return absl::OkStatus(); -} - -inline bool ShouldHandleFormatError(const ErrorOptions& error_options, - UChar32 ch, bool format_error) { - return ((error_options.replace_control_chars && ch <= 0x1F) || format_error); -} - -} // namespace - -class SentenceFragmentsOp : public OpKernel { - public: - explicit SentenceFragmentsOp(OpKernelConstruction* context) - : OpKernel(context) { - 
OP_REQUIRES_OK(context, GetErrorOptions(context, &error_options_)); - - OP_REQUIRES_OK(context, - context->GetAttr("input_encoding", &input_encoding_)); - // Make a temporary UConverter to ensure it will create without error - // at execution time (and to warm any data caches the converter needs). - // This instance is not used. - std::unique_ptr input_encoder = - std::make_unique(); - input_encoder->init(input_encoding_); - OP_REQUIRES( - context, input_encoder->converter_, - InvalidArgument("Could not create converter for input encoding: " + - input_encoding_)); - } - - void Compute(::tensorflow::OpKernelContext* context) override { -#define DECLARE_AND_VALIDATE_INPUT_VECTOR(name, dtype) \ - const Tensor* name##_tensor; \ - OP_REQUIRES_OK(context, context->input(#name, &name##_tensor)); \ - OP_REQUIRES(context, TensorShapeUtils::IsVector(name##_tensor->shape()), \ - InvalidArgument( \ - absl::StrCat("'", #name, "' must be a vector, got shape: ", \ - name##_tensor->shape().DebugString()))); \ - const auto& name = name##_tensor->vec(); - - DECLARE_AND_VALIDATE_INPUT_VECTOR(row_lengths, int64); - DECLARE_AND_VALIDATE_INPUT_VECTOR(token_start, int64); - DECLARE_AND_VALIDATE_INPUT_VECTOR(token_end, int64); - DECLARE_AND_VALIDATE_INPUT_VECTOR(token_word, tstring); - DECLARE_AND_VALIDATE_INPUT_VECTOR(token_properties, int64); - -#undef DECLARE_AND_VALIDATE_INPUT_TENSOR - - static thread_local std::unique_ptr input_encoder; - if (!input_encoder) { - input_encoder = std::make_unique(); - } - input_encoder->init(input_encoding_); - OP_REQUIRES( - context, input_encoder->converter_, - InvalidArgument("Could not create converter for input encoding: " + - input_encoding_)); - - UConverter* converter = input_encoder->converter_; - UnicodeUtil util(converter); - - int num_elements = 0; - for (int i = 0; i < row_lengths.size(); ++i) { - num_elements += row_lengths(i); - } - OP_REQUIRES(context, - num_elements == token_start.size() && - token_start.size() == token_end.size() && - 
token_end.size() == token_word.size(), - InvalidArgument(absl::StrCat( - "num_elements(", num_elements, "), token_start(", - token_start.size(), "), token_end(", token_end.size(), - "), token_word(", token_word.size(), - ") must all be the same size."))); - - // Iterate through the text - int token_index = 0; - int num_fragments = 0; - std::vector> fragments; - for (int i = 0; i < row_lengths.size(); ++i) { - std::vector tokens; - Document doc(&tokens); - for (int j = 0; j < row_lengths(i); ++j) { - doc.AddToken( - token_word(token_index), token_start(token_index), - token_end(token_index), Token::SPACE_BREAK, - static_cast(token_properties(token_index))); - ++token_index; - } - - // Find fragments. - SentenceFragmenter fragmenter(&doc, &util); - std::vector frags; - OP_REQUIRES_OK(context, fragmenter.FindFragments(&frags)); - - num_fragments += frags.size(); - fragments.push_back(std::move(frags)); - } - - std::vector fragment_shape; - fragment_shape.push_back(num_fragments); - - std::vector doc_batch_shape; - doc_batch_shape.push_back(fragments.size()); - -#define DECLARE_OUTPUT_TENSOR(name, out_shape) \ - Tensor* name##_tensor = nullptr; \ - OP_REQUIRES_OK(context, context->allocate_output( \ - #name, TensorShape(out_shape), &name##_tensor)); \ - auto name = name##_tensor->vec(); - - DECLARE_OUTPUT_TENSOR(fragment_start, fragment_shape); - DECLARE_OUTPUT_TENSOR(fragment_end, fragment_shape); - DECLARE_OUTPUT_TENSOR(fragment_properties, fragment_shape); - DECLARE_OUTPUT_TENSOR(terminal_punc_token, fragment_shape); - DECLARE_OUTPUT_TENSOR(output_row_lengths, doc_batch_shape); - -#undef DECLARE_OUTPUT_TENSOR - - // output_row_splits should have shape of - // [number of fragments over the entire batch] - int element_index = 0; - // Iterate through all the documents - for (int i = 0; i < fragments.size(); ++i) { - const std::vector& fragments_in_doc = fragments[i]; - // Iterate through all the fragments of a document - for (int j = 0; j < fragments_in_doc.size(); 
++j) { - const SentenceFragment& fragment = fragments_in_doc[j]; - fragment_start(element_index) = fragment.start; - fragment_end(element_index) = fragment.limit; - fragment_properties(element_index) = fragment.properties; - terminal_punc_token(element_index) = fragment.terminal_punc_token; - ++element_index; - } - output_row_lengths(i) = fragments_in_doc.size(); - } - } - - private: - string input_encoding_; - ErrorOptions error_options_; -}; - -REGISTER_KERNEL_BUILDER(Name("SentenceFragments").Device(DEVICE_CPU), - SentenceFragmentsOp); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_breaking_utils.cc b/tensorflow_text/core/kernels/sentence_breaking_utils.cc deleted file mode 100644 index 5bdcc1c02..000000000 --- a/tensorflow_text/core/kernels/sentence_breaking_utils.cc +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/sentence_breaking_utils.h" - -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" - -using ::tensorflow::Status; - -namespace tensorflow { -namespace text { - -absl::Status UnicodeUtil::GetOneUChar(const absl::string_view& input, - bool* has_more_than_one_char, - UChar32* result) const { - UErrorCode status = U_ZERO_ERROR; - const char* source = input.data(); - const char* limit = input.data() + input.length(); - if (!converter_) { - return tensorflow::errors::Internal( - absl::StrCat("Converter has not been initialized!")); - } - *result = ucnv_getNextUChar(converter_, &source, limit, &status); - - if (U_FAILURE(status)) { - return tensorflow::errors::Internal( - absl::StrCat("Failed to decode string, error status=", status)); - } - - if (source != limit) { - *has_more_than_one_char = true; - } else { - *has_more_than_one_char = false; - } - - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsTerminalPunc(const absl::string_view& input, - bool* result) const { - *result = false; - const auto& ellipsis_status = IsEllipsis(input, result); - // If there was a error decoding, or if we found an ellipsis, then return. - if (!ellipsis_status.ok()) return ellipsis_status; - if (*result) return absl::OkStatus(); - - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. 
- switch (char_value) { - case 0x055C: // Armenian exclamation mark - case 0x055E: // Armenian question mark - case 0x17d4: // Khmer sign khan - case 0x037E: // Greek question mark - case 0x2026: // ellipsis - *result = true; - return absl::OkStatus(); - } - - USentenceBreak sb_property = static_cast( - u_getIntPropertyValue(char_value, UCHAR_SENTENCE_BREAK)); - *result = sb_property == U_SB_ATERM || sb_property == U_SB_STERM; - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsClosePunc(const absl::string_view& input, - bool* result) const { - *result = false; - if (input == "''") { - *result = true; - return absl::OkStatus(); - } - - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. 
- switch (char_value) { - case '>': - case ']': - case '`': - case 64831: // Ornate right parenthesis - case 65282: // fullwidth quotation mark - case 65287: // fullwidth apostrophe - *result = true; - return absl::OkStatus(); - } - - ULineBreak lb_property = static_cast( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - - *result = lb_property == U_LB_CLOSE_PUNCTUATION || - lb_property == U_LB_CLOSE_PARENTHESIS || - lb_property == U_LB_QUOTATION; - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsOpenParen(const absl::string_view& input, - bool* result) const { - *result = false; - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. - switch (char_value) { - case '<': - case 64830: // Ornate left parenthesis - *result = true; - return absl::OkStatus(); - } - - ULineBreak lb_property = static_cast( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - *result = lb_property == U_LB_OPEN_PUNCTUATION; - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsCloseParen(const absl::string_view& input, - bool* result) const { - *result = false; - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. 
- switch (char_value) { - case '>': - case 64831: // Ornate right parenthesis - *result = true; - return absl::OkStatus(); - } - - ULineBreak lb_property = static_cast( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - *result = lb_property == U_LB_CLOSE_PUNCTUATION || - lb_property == U_LB_CLOSE_PARENTHESIS; - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsPunctuationWord(const absl::string_view& input, - bool* result) const { - *result = false; - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. - switch (char_value) { - case '`': - case '<': - case '>': - case '~': - case 5741: - *result = true; - return absl::OkStatus(); - } - - *result = u_ispunct(char_value) || - u_hasBinaryProperty(char_value, UCHAR_DASH) || - u_hasBinaryProperty(char_value, UCHAR_HYPHEN); - return absl::OkStatus(); -} - -absl::Status UnicodeUtil::IsEllipsis(const absl::string_view& input, - bool* result) const { - *result = false; - if (input == "...") { - *result = true; - return absl::OkStatus(); - } - - bool has_more_than_one_char = false; - UChar32 char_value; - const auto& status = GetOneUChar(input, &has_more_than_one_char, &char_value); - if (!status.ok()) return status; - if (has_more_than_one_char) { - *result = false; - return absl::OkStatus(); - } - - *result = char_value == 0x2026; - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_breaking_utils.h b/tensorflow_text/core/kernels/sentence_breaking_utils.h index 02534a7e3..b4588afae 100644 --- a/tensorflow_text/core/kernels/sentence_breaking_utils.h +++ 
b/tensorflow_text/core/kernels/sentence_breaking_utils.h @@ -12,57 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ -#include -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/ucnv.h" -#include "icu4c/source/common/unicode/ucnv_err.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/kernels/text/sentence_breaking_utils.h" -namespace tensorflow { -namespace text { - -// A class of utils for identifying certain classes and properties of unicode -// characters. -class UnicodeUtil { - public: - // `converter` not owned. - explicit UnicodeUtil(UConverter* converter) : converter_(converter) {} - - // Returns true iff a string is terminal punctuation. - absl::Status IsTerminalPunc(const absl::string_view& input, - bool* result) const; - - // Returns true iff a string is close punctuation (close quote or close - // paren). - absl::Status IsClosePunc(const absl::string_view& input, bool* result) const; - - // Returns true iff a string is an open paren. - absl::Status IsOpenParen(const absl::string_view& input, bool* result) const; - - // Returns true iff a string is a close paren. - absl::Status IsCloseParen(const absl::string_view& input, bool* result) const; - - // Returns true iff a word is made of punctuation characters only. - absl::Status IsPunctuationWord(const absl::string_view& input, - bool* result) const; - - // Returns true iff a string is an ellipsis token ("..."). 
- absl::Status IsEllipsis(const absl::string_view& input, bool* result) const; - - private: - absl::Status GetOneUChar(const absl::string_view&, - bool* has_more_than_one_char, UChar32* result) const; - - // not owned. mutable because UConverter contains some internal options and - // buffer. - mutable UConverter* converter_; -}; - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_BREAKING_UTILS_H_ diff --git a/tensorflow_text/core/kernels/sentence_breaking_utils_test.cc b/tensorflow_text/core/kernels/sentence_breaking_utils_test.cc deleted file mode 100644 index 7aee97091..000000000 --- a/tensorflow_text/core/kernels/sentence_breaking_utils_test.cc +++ /dev/null @@ -1,576 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/sentence_breaking_utils.h" - -#include -#include -#include - -#include -#include -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/ucnv.h" -#include "icu4c/source/common/unicode/ucnv_err.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/uniset.h" -#include "icu4c/source/common/unicode/unistr.h" -#include "icu4c/source/common/unicode/uset.h" -#include "icu4c/source/common/unicode/utypes.h" - -namespace tensorflow { -namespace text { -namespace { - -class SentenceBreakingUtilsTest { - protected: - UConverter* GetUConverter() { - constexpr char name[] = "UTF-8"; - UErrorCode status = U_ZERO_ERROR; - UConverter* converter = ucnv_open(name, &status); - if (U_FAILURE(status)) { - if (converter) { - ucnv_close(converter); - } - return nullptr; - } - return converter; - } -}; - -class SentenceBreakingUtilsParamTest : public SentenceBreakingUtilsTest, - public ::testing::TestWithParam { - protected: - void SetUp() override { - converter_ = SentenceBreakingUtilsTest::GetUConverter(); - ASSERT_NE(converter_, nullptr); - } - - void TearDown() override { ucnv_close(converter_); } - - std::string StringFromUnicodeChar(UChar32 input) { - std::string result; - icu::UnicodeString test_unicode_string(input); - test_unicode_string.toUTF8String(result); - return result; - } - - UConverter* converter_; -}; - -class IsTerminalPuncParamTest : public SentenceBreakingUtilsParamTest {}; - -class IsTerminalPuncTest : public SentenceBreakingUtilsTest, - public ::testing::Test {}; - -const UChar is_terminal_punc_test_cases[] = { - 0x055C, // Armenian exclamation mark - 0x055E, // Armenian question mark - 0x0589, // Armenian full stop - 0x061F, // Arabic question mark - 0x06D4, // Arabic full stop - 0x0700, // Syriabc end of paragraph - 0x0701, // Syriac supralinear full stop - 0x0702, // Syriac sublinear full stop - 0x1362, // Ethiopic full stop - 0x1367, // Ethiopic 
question mark - 0x1368, // Ethiopic paragraph separator - 0x104A, // Myanmar sign little section - 0x104B, // Myanmar sign section - 0x166E, // Canadian syllabics full stop - 0x17d4, // Khmer sign khan - 0x1803, // Mongolian full stop - 0x1809, // Mongolian Manchu full stop - 0x1944, // Limbu exclamation mark - 0x1945, // Limbu question mark - 0x203C, // double exclamation mark - 0x203D, // interrobang - 0x2047, // double question mark - 0x2048, // question exclamation mark - 0x2049, // exclamation question mark - 0x3002, // ideographic full stop - 0x037E, // Greek question mark - 0xFE52, // small full stop - 0xFE56, // small question mark - 0xFE57, // small exclamation mark - 0xFF01, // fullwidth exclamation mark - 0xFF0E, // fullwidth full stop - 0xFF1F, // fullwidth question mark - 0xFF61, // halfwidth ideographic full stop - 0x2026, // ellipsis - 0x0964, - 0x0965, // Devanagari danda..Devanagari double -}; - -TEST_P(IsTerminalPuncParamTest, IsTerminalPunc) { - UnicodeUtil util(converter_); - std::string test_string = StringFromUnicodeChar(GetParam()); - bool result = false; - EXPECT_TRUE(util.IsTerminalPunc(test_string, &result).ok()); - EXPECT_TRUE(result); -} - -INSTANTIATE_TEST_SUITE_P(IsTerminalPuncTest, IsTerminalPuncParamTest, - ::testing::ValuesIn(is_terminal_punc_test_cases)); - -TEST_F(IsTerminalPuncTest, IsMultiCharEllipseTerminalPunc) { - UConverter* converter = SentenceBreakingUtilsTest::GetUConverter(); - ASSERT_NE(converter, nullptr); - UnicodeUtil util(converter); - std::string test_string = "..."; - bool result; - EXPECT_TRUE(util.IsTerminalPunc(test_string, &result).ok()); - EXPECT_TRUE(result); - ucnv_close(converter); -} - -TEST_F(IsTerminalPuncTest, TestMultiUnicodeChars) { - UConverter* converter = SentenceBreakingUtilsTest::GetUConverter(); - ASSERT_NE(converter, nullptr); - UnicodeUtil util(converter); - std::string test_string = "never gonna let you decode"; - bool result; - EXPECT_TRUE(util.IsTerminalPunc(test_string, &result).ok()); - 
EXPECT_FALSE(result); - ucnv_close(converter); -} - -TEST_F(IsTerminalPuncTest, TestInvalidConverter) { - UErrorCode status = U_ZERO_ERROR; - UConverter* converter = ucnv_open("cant find me", &status); - UnicodeUtil util(converter); - std::string test_string = "."; - bool result; - EXPECT_FALSE(util.IsTerminalPunc(test_string, &result).ok()); - ucnv_close(converter); -} - -class ClosePuncParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar close_punc_test_cases[] = { - 0x29, 0x5D, 0x3E, 0x7D, - 0x207E, // superscript right parenthesis - 0x208E, // subscript right parenthesis - 0x27E7, // mathematical right white square bracket - 0x27E9, // mathematical right angle bracket - 0x27EB, // mathematical right double angle bracket - 0x2984, // right white curly bracket - 0x2986, // right white parenthesis - 0x2988, // Z notation right image bracket - 0x298A, // Z notation right binding bracket - 0x298C, // right square bracket with underbar - 0x298E, // right square bracket with tick in top corner - 0x2990, // right square bracket with tick in bottom corner - 0x2992, // right angle bracket with dot - 0x2994, // right arc greater-than bracket - 0x2996, // double right arc less-than bracket - 0x2998, // right black tortoise shell bracket - 0x29D9, // right wiggly fence - 0x29DB, // right double wiggly fence - 0x29FD, // right-pointing curved angle bracket - 0x3009, // CJK right angle bracket - 0x300B, // CJK right double angle bracket - 0x3011, // CJK right black lenticular bracket - 0x3015, // CJK right tortoise shell bracket - 0x3017, // CJK right white lenticular bracket - 0x3019, // CJK right white tortoise shell bracket - 0x301B, // CJK right white square bracket - 0xFD3F, // Ornate right parenthesis - 0xFE5A, // small right parenthesis - 0xFE5C, // small right curly bracket - 0xFF09, // fullwidth right parenthesis - 0xFF3D, // fullwidth right square bracket - 0xFF5D, // fullwidth right curly bracket - 0x27, 0x60, 0x22, - 0xFF07, // fullwidth apostrophe 
- 0xFF02, // fullwidth quotation mark - 0x2019, // right single quotation mark (English, others) - 0x201D, // right double quotation mark (English, others) - 0x2018, // left single quotation mark (Czech, German, Slovak) - 0x201C, // left double quotation mark (Czech, German, Slovak) - 0x203A, // single right-pointing angle quotation mark (French, others) - 0x00BB, // right-pointing double angle quotation mark (French, others) - 0x2039, // single left-pointing angle quotation mark (Slovenian, others) - 0x00AB, // left-pointing double angle quotation mark (Slovenian, others) - 0x300D, // right corner bracket (East Asian languages) - 0xfe42, // presentation form for vertical right corner bracket - 0xFF63, // halfwidth right corner bracket (East Asian languages) - 0x300F, // right white corner bracket (East Asian languages) - 0xfe44, // presentation form for vertical right white corner bracket - 0x301F, // low double prime quotation mark (East Asian languages) - 0x301E, // close double prime (East Asian languages written horizontally) -}; - -TEST_P(ClosePuncParamTest, IsClosePunc) { - UnicodeUtil util(converter_); - std::string test_string = StringFromUnicodeChar(GetParam()); - bool result = false; - EXPECT_TRUE(util.IsClosePunc(test_string, &result).ok()); - EXPECT_TRUE(result); -} - -INSTANTIATE_TEST_SUITE_P(IsClosePuncParamTest, ClosePuncParamTest, - ::testing::ValuesIn(close_punc_test_cases)); - -class OpenParenParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar open_paren_test_cases[] = { - '(', '[', '<', '{', - 0x207D, // superscript left parenthesis - 0x208D, // subscript left parenthesis - 0x27E6, // mathematical left white square bracket - 0x27E8, // mathematical left angle bracket - 0x27EA, // mathematical left double angle bracket - 0x2983, // left white curly bracket - 0x2985, // left white parenthesis - 0x2987, // Z notation left image bracket - 0x2989, // Z notation left binding bracket - 0x298B, // left square bracket with underbar - 
0x298D, // left square bracket with tick in top corner - 0x298F, // left square bracket with tick in bottom corner - 0x2991, // left angle bracket with dot - 0x2993, // left arc less-than bracket - 0x2995, // double left arc greater-than bracket - 0x2997, // left black tortoise shell bracket - 0x29D8, // left wiggly fence - 0x29DA, // left double wiggly fence - 0x29FC, // left-pointing curved angle bracket - 0x3008, // CJK left angle bracket - 0x300A, // CJK left double angle bracket - 0x3010, // CJK left black lenticular bracket - 0x3014, // CJK left tortoise shell bracket - 0x3016, // CJK left white lenticular bracket - 0x3018, // CJK left white tortoise shell bracket - 0x301A, // CJK left white square bracket - 0xFD3E, // Ornate left parenthesis - 0xFE59, // small left parenthesis - 0xFE5B, // small left curly bracket - 0xFF08, // fullwidth left parenthesis - 0xFF3B, // fullwidth left square bracket - 0xFF5B, // fullwidth left curly bracket -}; - -TEST_P(OpenParenParamTest, IsOpenParen) { - UnicodeUtil util(converter_); - std::string test_string = StringFromUnicodeChar(GetParam()); - bool result = false; - EXPECT_TRUE(util.IsOpenParen(test_string, &result).ok()); - EXPECT_TRUE(result); -} - -INSTANTIATE_TEST_SUITE_P(IsOpenParenParamTest, OpenParenParamTest, - ::testing::ValuesIn(open_paren_test_cases)); - -class CloseParenParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar close_paren_test_cases[] = { - ')', ']', '>', '}', - 0x207E, // superscript right parenthesis - 0x208E, // subscript right parenthesis - 0x27E7, // mathematical right white square bracket - 0x27E9, // mathematical right angle bracket - 0x27EB, // mathematical right double angle bracket - 0x2984, // right white curly bracket - 0x2986, // right white parenthesis - 0x2988, // Z notation right image bracket - 0x298A, // Z notation right binding bracket - 0x298C, // right square bracket with underbar - 0x298E, // right square bracket with tick in top corner - 0x2990, // right 
square bracket with tick in bottom corner - 0x2992, // right angle bracket with dot - 0x2994, // right arc greater-than bracket - 0x2996, // double right arc less-than bracket - 0x2998, // right black tortoise shell bracket - 0x29D9, // right wiggly fence - 0x29DB, // right double wiggly fence - 0x29FD, // right-pointing curved angle bracket - 0x3009, // CJK right angle bracket - 0x300B, // CJK right double angle bracket - 0x3011, // CJK right black lenticular bracket - 0x3015, // CJK right tortoise shell bracket - 0x3017, // CJK right white lenticular bracket - 0x3019, // CJK right white tortoise shell bracket - 0x301B, // CJK right white square bracket - 0xFD3F, // Ornate right parenthesis - 0xFE5A, // small right parenthesis - 0xFE5C, // small right curly bracket - 0xFF09, // fullwidth right parenthesis - 0xFF3D, // fullwidth right square bracket - 0xFF5D, // fullwidth right curly bracket -}; - -TEST_P(CloseParenParamTest, IsCloseParen) { - UnicodeUtil util(converter_); - std::string test_string = StringFromUnicodeChar(GetParam()); - bool result = false; - EXPECT_TRUE(util.IsCloseParen(test_string, &result).ok()); - EXPECT_TRUE(result); -} - -INSTANTIATE_TEST_SUITE_P(IsCloseParenParamTest, CloseParenParamTest, - ::testing::ValuesIn(close_paren_test_cases)); - -class IsPunctuationWordParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar punc_word_test_cases[] = { - '(', '[', '<', '{', - 0x207D, // superscript left parenthesis - 0x208D, // subscript left parenthesis - 0x27E6, // mathematical left white square bracket - 0x27E8, // mathematical left angle bracket - 0x27EA, // mathematical left double angle bracket - 0x2983, // left white curly bracket - 0x2985, // left white parenthesis - 0x2987, // Z notation left image bracket - 0x2989, // Z notation left binding bracket - 0x298B, // left square bracket with underbar - 0x298D, // left square bracket with tick in top corner - 0x298F, // left square bracket with tick in bottom corner - 0x2991, // left 
angle bracket with dot - 0x2993, // left arc less-than bracket - 0x2995, // double left arc greater-than bracket - 0x2997, // left black tortoise shell bracket - 0x29D8, // left wiggly fence - 0x29DA, // left double wiggly fence - 0x29FC, // left-pointing curved angle bracket - 0x3008, // CJK left angle bracket - 0x300A, // CJK left double angle bracket - 0x3010, // CJK left black lenticular bracket - 0x3014, // CJK left tortoise shell bracket - 0x3016, // CJK left white lenticular bracket - 0x3018, // CJK left white tortoise shell bracket - 0x301A, // CJK left white square bracket - 0xFD3E, // Ornate left parenthesis - 0xFE59, // small left parenthesis - 0xFE5B, // small left curly bracket - 0xFF08, // fullwidth left parenthesis - 0xFF3B, // fullwidth left square bracket - 0xFF5B, // fullwidth left curly bracket - '"', '\'', '`', - 0xFF07, // fullwidth apostrophe - 0xFF02, // fullwidth quotation mark - 0x2018, // left single quotation mark (English, others) - 0x201C, // left double quotation mark (English, others) - 0x201B, // single high-reveresed-9 quotation mark (PropList.txt) - 0x201A, // single low-9 quotation mark (Czech, German, Slovak) - 0x201E, // double low-9 quotation mark (Czech, German, Slovak) - 0x201F, // double high-reversed-9 quotation mark (PropList.txt) - 0x2019, // right single quotation mark (Danish, Finnish, Swedish, Norw.) - 0x201D, // right double quotation mark (Danish, Finnish, Swedish, Norw.) 
- 0x2039, // single left-pointing angle quotation mark (French, others) - 0x00AB, // left-pointing double angle quotation mark (French, others) - 0x203A, // single right-pointing angle quotation mark (Slovenian, others) - 0x00BB, // right-pointing double angle quotation mark (Slovenian, others) - 0x300C, // left corner bracket (East Asian languages) - 0xFE41, // presentation form for vertical left corner bracket - 0xFF62, // halfwidth left corner bracket (East Asian languages) - 0x300E, // left white corner bracket (East Asian languages) - 0xFE43, // presentation form for vertical left white corner bracket - 0x301D, // reversed double prime quotation mark (East Asian langs, horiz.) - ')', ']', '>', '}', - 0x207E, // superscript right parenthesis - 0x208E, // subscript right parenthesis - 0x27E7, // mathematical right white square bracket - 0x27E9, // mathematical right angle bracket - 0x27EB, // mathematical right double angle bracket - 0x2984, // right white curly bracket - 0x2986, // right white parenthesis - 0x2988, // Z notation right image bracket - 0x298A, // Z notation right binding bracket - 0x298C, // right square bracket with underbar - 0x298E, // right square bracket with tick in top corner - 0x2990, // right square bracket with tick in bottom corner - 0x2992, // right angle bracket with dot - 0x2994, // right arc greater-than bracket - 0x2996, // double right arc less-than bracket - 0x2998, // right black tortoise shell bracket - 0x29D9, // right wiggly fence - 0x29DB, // right double wiggly fence - 0x29FD, // right-pointing curved angle bracket - 0x3009, // CJK right angle bracket - 0x300B, // CJK right double angle bracket - 0x3011, // CJK right black lenticular bracket - 0x3015, // CJK right tortoise shell bracket - 0x3017, // CJK right white lenticular bracket - 0x3019, // CJK right white tortoise shell bracket - 0x301B, // CJK right white square bracket - 0xFD3F, // Ornate right parenthesis - 0xFE5A, // small right parenthesis - 0xFE5C, // small 
right curly bracket - 0xFF09, // fullwidth right parenthesis - 0xFF3D, // fullwidth right square bracket - 0xFF5D, // fullwidth right curly bracket - '\'', '"', '`', - 0xFF07, // fullwidth apostrophe - 0xFF02, // fullwidth quotation mark - 0x2019, // right single quotation mark (English, others) - 0x201D, // right double quotation mark (English, others) - 0x2018, // left single quotation mark (Czech, German, Slovak) - 0x201C, // left double quotation mark (Czech, German, Slovak) - 0x203A, // single right-pointing angle quotation mark (French, others) - 0x00BB, // right-pointing double angle quotation mark (French, others) - 0x2039, // single left-pointing angle quotation mark (Slovenian, others) - 0x00AB, // left-pointing double angle quotation mark (Slovenian, others) - 0x300D, // right corner bracket (East Asian languages) - 0xfe42, // presentation form for vertical right corner bracket - 0xFF63, // halfwidth right corner bracket (East Asian languages) - 0x300F, // right white corner bracket (East Asian languages) - 0xfe44, // presentation form for vertical right white corner bracket - 0x301F, // low double prime quotation mark (East Asian languages) - 0x301E, // close double prime (East Asian languages written horizontally) - 0x00A1, // Spanish inverted exclamation mark - 0x00BF, // Spanish inverted question mark - '.', '!', '?', - 0x055C, // Armenian exclamation mark - 0x055E, // Armenian question mark - 0x0589, // Armenian full stop - 0x061F, // Arabic question mark - 0x06D4, // Arabic full stop - 0x0700, // Syriac end of paragraph - 0x0701, // Syriac supralinear full stop - 0x0702, // Syriac sublinear full stop - 0x0964, // Devanagari danda..Devanagari double danda - 0x0965, - 0x1362, // Ethiopic full stop - 0x1367, // Ethiopic question mark - 0x1368, // Ethiopic paragraph separator - 0x104A, // Myanmar sign little section - 0x104B, // Myanmar sign section - 0x166E, // Canadian syllabics full stop - 0x17d4, // Khmer sign khan - 0x1803, // Mongolian full stop 
- 0x1809, // Mongolian Manchu full stop - 0x1944, // Limbu exclamation mark - 0x1945, // Limbu question mark - 0x203C, // double exclamation mark - 0x203D, // interrobang - 0x2047, // double question mark - 0x2048, // question exclamation mark - 0x2049, // exclamation question mark - 0x3002, // ideographic full stop - 0x037E, // Greek question mark - 0xFE52, // small full stop - 0xFE56, // small question mark - 0xFE57, // small exclamation mark - 0xFF01, // fullwidth exclamation mark - 0xFF0E, // fullwidth full stop - 0xFF1F, // fullwidth question mark - 0xFF61, // halfwidth ideographic full stop - 0x2026, // ellipsis - 0x30fb, // Katakana middle dot - 0xff65, // halfwidth Katakana middle dot - 0x2040, // character tie - '-', '~', - 0x058a, // Armenian hyphen - 0x1806, // Mongolian todo soft hyphen - 0x2010, // hyphen..horizontal bar - 0x2011, 0x2012, 0x2013, 0x2014, 0x2015, - 0x2053, // swung dash -- from Table 6-3 of Unicode book - 0x207b, // superscript minus - 0x208b, // subscript minus - 0x2212, // minus sign - 0x301c, // wave dash - 0x3030, // wavy dash - 0xfe31, // presentation form for vertical em dash..en dash - 0xfe32, - 0xfe58, // small em dash - 0xfe63, // small hyphen-minus - 0xff0d, // fullwidth hyphen-minus - ',', ':', ';', - 0x00b7, // middle dot - 0x0387, // Greek ano teleia - 0x05c3, // Hebrew punctuation sof pasuq - 0x060c, // Arabic comma - 0x061b, // Arabic semicolon - 0x066b, // Arabic decimal separator - 0x066c, // Arabic thousands separator - 0x0703, // Syriac contraction and others - 0x0704, 0x0705, 0x0706, 0x0707, 0x0708, 0x0709, 0x70a, - 0x070c, // Syric harklean metobelus - 0x0e5a, // Thai character angkhankhu - 0x0e5b, // Thai character khomut - 0x0f08, // Tibetan mark sbrul shad - 0x0f0d, // Tibetan mark shad..Tibetan mark rgya gram shad - 0x0f0e, 0x0f0f, 0x0f10, 0x0f11, 0x0f12, - 0x1361, // Ethiopic wordspace - 0x1363, // other Ethiopic chars - 0x1364, 0x1365, 0x1366, - 0x166d, // Canadian syllabics chi sign - 0x16eb, // Runic single 
punctuation..Runic cross punctuation - 0x16ed, - 0x17d5, // Khmer sign camnuc pii huuh and other - 0x17d6, - 0x17da, // Khmer sign koomut - 0x1802, // Mongolian comma - 0x1804, // Mongolian four dots and other - 0x1805, - 0x1808, // Mongolian manchu comma - 0x3001, // ideographic comma - 0xfe50, // small comma and others - 0xfe51, - 0xfe54, // small semicolon and other - 0xfe55, - 0xff0c, // fullwidth comma - 0xff0e, // fullwidth stop..fullwidth solidus - 0xff0f, - 0xff1a, // fullwidth colon..fullwidth semicolon - 0xff1b, - 0xff64, // halfwidth ideographic comma - 0x2016, // double vertical line - 0x2032, 0x2033, - 0x2034, // prime..triple prime - 0xfe61, // small asterisk - 0xfe68, // small reverse solidus - 0xff3c, // fullwidth reverse solidus -}; - -TEST_P(IsPunctuationWordParamTest, IsPunctuation) { - UnicodeUtil util(converter_); - std::string test_string = StringFromUnicodeChar(GetParam()); - bool result = false; - EXPECT_TRUE(util.IsPunctuationWord(test_string, &result).ok()); - EXPECT_TRUE(result); -} - -INSTANTIATE_TEST_SUITE_P(IsPuncWordParamTest, IsPunctuationWordParamTest, - ::testing::ValuesIn(punc_word_test_cases)); - -class IsEllipsisTest : public SentenceBreakingUtilsTest, - public ::testing::Test { - protected: - void SetUp() override { - converter_ = SentenceBreakingUtilsTest::GetUConverter(); - } - - void TearDown() override { ucnv_close(converter_); } - - UConverter* converter_; -}; - -TEST_F(IsEllipsisTest, IsEllipsis) { - UnicodeUtil util(converter_); - bool result = false; - EXPECT_TRUE(util.IsEllipsis("...", &result).ok()); - EXPECT_TRUE(result); - - EXPECT_TRUE(util.IsEllipsis("…", &result).ok()); - EXPECT_TRUE(result); - - EXPECT_TRUE(util.IsEllipsis("@", &result).ok()); - EXPECT_FALSE(result); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_fragmenter.cc b/tensorflow_text/core/kernels/sentence_fragmenter.cc deleted file mode 100644 index c336b5cfa..000000000 --- 
a/tensorflow_text/core/kernels/sentence_fragmenter.cc +++ /dev/null @@ -1,443 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/sentence_fragmenter.h" -#include -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow_text/core/kernels/sentence_breaking_utils.h" - -using ::tensorflow::Status; - -namespace tensorflow { -namespace text { -namespace { - -// Sets a property of a sentence fragment. -void SetFragmentProperty(SentenceFragment::Property property, - SentenceFragment *fragment) { - fragment->properties = fragment->properties | property; -} - -// Returns true iff a token has any of the given properties. -bool TokenHasProperty(uint32 properties, const Token &token) { - return token.text_properties() & properties; -} - -// Returns true iff a token has the ACRONYM text property and token.word() -// ends with a period. -bool IsPeriodSeparatedAcronym(const Token &token) { - return TokenHasProperty(Token::ACRONYM, token) && - (!token.word().empty() && token.word().back() == '.'); -} - -// Returns true iff the token can appear after a space in a sentence-terminal -// token sequence. 
-absl::Status SpaceAllowedBeforeToken(const UnicodeUtil *util, - const Token &token, bool *result) { - const tstring &word = token.word(); - bool is_ellipsis = false; - TF_RETURN_IF_ERROR(util->IsEllipsis(word, &is_ellipsis)); - - bool is_terminal_punc = false; - TF_RETURN_IF_ERROR(util->IsTerminalPunc(word, &is_terminal_punc)); - - bool is_close_paren = false; - TF_RETURN_IF_ERROR(util->IsCloseParen(word, &is_close_paren)); - - *result = (TokenHasProperty(Token::EMOTICON, token) || - (is_ellipsis || is_terminal_punc || is_close_paren)); - return absl::OkStatus(); -} -} // namespace - -class SentenceFragmenter::FragmentBoundaryMatch { - public: - FragmentBoundaryMatch() { - Reset(); - } - - // Goes to initial state. - void Reset() { - state_ = INITIAL_STATE; - first_terminal_punc_index_ = -1; - first_close_punc_index_ = -1; - limit_index_ = -1; - } - - // Follows the state transition for the token at the given index. Returns - // true for success, or false if there was no valid transition. - absl::Status Advance(const UnicodeUtil *util, const Document &document, - int index, bool *result) { - const Token &token = document.tokens()[index]; - const tstring &word = token.word(); - bool no_transition = false; - - bool is_terminal_punc = false; - TF_RETURN_IF_ERROR(util->IsTerminalPunc(word, &is_terminal_punc)); - - bool is_ellipsis = false; - TF_RETURN_IF_ERROR(util->IsEllipsis(word, &is_ellipsis)); - - bool is_close_punc = false; - TF_RETURN_IF_ERROR(util->IsClosePunc(word, &is_close_punc)); - - switch (state_) { - case INITIAL_STATE: - if (is_terminal_punc || is_ellipsis || - IsPeriodSeparatedAcronym(token) || - TokenHasProperty(Token::EMOTICON, token)) { - first_terminal_punc_index_ = index; - state_ = COLLECTING_TERMINAL_PUNC; - } - break; - case COLLECTING_TERMINAL_PUNC: - - if (is_terminal_punc || is_ellipsis || - TokenHasProperty(Token::EMOTICON, token)) { - // Stay in COLLECTING_TERMINAL_PUNC state. 
- } else if (is_close_punc) { - first_close_punc_index_ = index; - state_ = COLLECTING_CLOSE_PUNC; - } else { - no_transition = true; - } - break; - case COLLECTING_CLOSE_PUNC: - if (is_close_punc || is_ellipsis || - TokenHasProperty(Token::EMOTICON, token)) { - // Stay in COLLECTING_CLOSE_PUNC state. We effectively ignore - // emoticons and ellipses and continue to accept closing punctuation - // after them. - } else { - no_transition = true; - } - break; - } - - if (no_transition) { - *result = false; - return absl::OkStatus(); - } else { - limit_index_ = index + 1; - if (state_ == COLLECTING_TERMINAL_PUNC) { - // We've gotten terminal punctuation, but no close punctuation yet. - first_close_punc_index_ = limit_index_; - } - *result = true; - return absl::OkStatus(); - } - } - - // Returns true iff we have matched at least one terminal punctuation - // character. - bool GotTerminalPunc() const { - return first_terminal_punc_index_ >= 0; - } - - // Field accessors. - int first_terminal_punc_index() const { - return first_terminal_punc_index_; - } - int first_close_punc_index() const { - return first_close_punc_index_; - } - int limit_index() const { - return limit_index_; - } - - private: - // Match state. - enum MatchState { - INITIAL_STATE = 0, - COLLECTING_TERMINAL_PUNC, - COLLECTING_CLOSE_PUNC - }; - MatchState state_ = INITIAL_STATE; - - // First terminal punctuation mark matched; may be an acronym. - // -1 for not found. - int first_terminal_punc_index_ = -1; - - // First closing punctuation mark matched. -1 for not found. - int first_close_punc_index_ = -1; - - // First token after the terminal sequence. - int limit_index_ = -1; -}; - -absl::Status SentenceFragmenter::FindFragments( - std::vector *result) { - // Partition tokens into sentence fragments. - for (int i_start = 0; i_start < document_->tokens().size();) { - SentenceFragment fragment; - - // Match regexp for fragment boundary. 
- FragmentBoundaryMatch match; - TF_RETURN_IF_ERROR(FindNextFragmentBoundary(i_start, &match)); - - // Update 'latest_open_paren_is_sentential_' for the tokens in this - // fragment. - TF_RETURN_IF_ERROR( - UpdateLatestOpenParenForFragment(i_start, match.limit_index())); - - // Add a new sentence fragment up to this boundary. - TF_RETURN_IF_ERROR(FillInFragmentFields(i_start, match, &fragment)); - - result->push_back(std::move(fragment)); - i_start = match.limit_index(); - } - return absl::OkStatus(); -} - -// This method is essentially a control layer on top of a simple state machine -// that matches an end-of-fragment regexp. This method finds the next token to -// feed to the state machine, and handles embedded whitespace. The main -// complexity is that a space may delimit end-of-match, or be embedded in the -// termination sequence. When we encounter a space, we record the match found so -// far, but also continue matching. We return the longer match if it succeeds, -// else fall back to the earlier one. Note that the lookahead can incur at most -// 2n cost. -// -// E.g., suppose we're given: x? !!!y. We encounter the space after "x?" and -// have to look ahead all the way to "y" before realizing that the longer match -// fails. We put a fragment boundary after "x?", and next time around, we again -// scan "!!!" looking for a fragment boundary. Since we failed to find one last -// time, we'll fail again this time and therefore continue past "y" to find the -// next boundary. We will not try to scan "!!!" a third time. 
-absl::Status SentenceFragmenter::FindNextFragmentBoundary( - int i_start, SentenceFragmenter::FragmentBoundaryMatch *result) const { - FragmentBoundaryMatch current_match; - FragmentBoundaryMatch previous_match; - - for (int i = i_start; i < static_cast(document_->tokens().size()); ++i) { - const auto &token = document_->tokens()[i]; - if (current_match.GotTerminalPunc() && i > i_start && - token.break_level() >= Token::SPACE_BREAK) { - // Got terminal punctuation and a space delimiter, so match is valid. - bool space_allowed_before_token = false; - TF_RETURN_IF_ERROR( - SpaceAllowedBeforeToken(util_, token, &space_allowed_before_token)); - if (space_allowed_before_token) { - // Remember this match. Try to extend it. - previous_match = current_match; - } else { - // Stop here. We're not allowed to extend the match in this case. - break; - } - } - bool got_transition = false; - TF_RETURN_IF_ERROR( - current_match.Advance(util_, *document_, i, &got_transition)); - if (!got_transition) { - if (previous_match.GotTerminalPunc()) { - // Extension failed. Return previous match. - *result = previous_match; - return absl::OkStatus(); - } else { - // Start matching again from scratch. - current_match.Reset(); - - // Reprocess current token since it might be terminal punctuation. No - // infinite loop, because can't be "no transition" from INITIAL_STATE. - --i; - } - } - } - *result = current_match; - return absl::OkStatus(); -} - -// Keep track of whether the latest open parenthesis seen so far appears to be -// sentence-initial. This is useful because if it is *non-sentence-initial*, -// then any terminal punctuation before the corresponding close paren is -// probably not a sentence boundary. Example: -// -// Mushrooms (they're fungi!!) are delicious. -// (Mushrooms are fungi!!) -// -// In the first case, the open paren is non-sentence-initial, and therefore -// the "!!)" is not a sentence boundary. 
In the second case, the open paren *is* -// sentence-initial, and so the "!!)" is a sentence boundary. -// -// Of course, we don't know true sentence boundaries, so we make the -// approximation that an open paren is sentence-initial iff it is -// fragment-initial. This will be wrong if the open paren occurs after terminal -// punctuation that turns out not to be a sentence boundary, e.g., -// "Yahoo! (known for search, etc.) blah", but this is not expected to happen -// often. -absl::Status SentenceFragmenter::UpdateLatestOpenParenForFragment(int i_start, - int i_end) { - for (int i = i_end; i > i_start; --i) { - const auto &token = document_->tokens()[i - 1]; - bool is_open_paren = false; - TF_RETURN_IF_ERROR(util_->IsOpenParen(token.word(), &is_open_paren)); - if (is_open_paren) { - // Make the approximation that this open paren is sentence-initial iff it - // is fragment-initial. - latest_open_paren_is_sentential_ = (i - 1 == i_start); - break; - } - } - - return absl::OkStatus(); -} - -absl::Status SentenceFragmenter::FillInFragmentFields( - int i_start, const FragmentBoundaryMatch &match, - SentenceFragment *fragment) const { - // Set the fragment's boundaries. - fragment->start = i_start; - fragment->limit = match.limit_index(); - - // Set the fragment's properties. - if (match.GotTerminalPunc()) { - // TERMINAL_PUNC. - SetFragmentProperty(SentenceFragment::TERMINAL_PUNC, fragment); - int terminal_punc_index = -1; - TF_RETURN_IF_ERROR( - GetAdjustedFirstTerminalPuncIndex(match, &terminal_punc_index)); - bool has_unattachable_terminal_punc = false; - TF_RETURN_IF_ERROR( - HasUnattachableTerminalPunc(match, &has_unattachable_terminal_punc)); - bool has_close_paren = false; - TF_RETURN_IF_ERROR(HasCloseParen(match, &has_close_paren)); - - fragment->terminal_punc_token = terminal_punc_index; - // MULTIPLE_TERMINAL_PUNC. 
- if (has_unattachable_terminal_punc) { - SetFragmentProperty(SentenceFragment::MULTIPLE_TERMINAL_PUNC, fragment); - } - - // HAS_CLOSE_PAREN & HAS_SENTENTIAL_CLOSE_PAREN. - if (has_close_paren) { - SetFragmentProperty(SentenceFragment::HAS_CLOSE_PAREN, fragment); - - if (latest_open_paren_is_sentential_) { - SetFragmentProperty(SentenceFragment::HAS_SENTENTIAL_CLOSE_PAREN, - fragment); - } - } - } - - return absl::OkStatus(); -} - -// The standard first terminal punctuation index is just -// match.first_terminal_punc_index(). But if there is an ambiguous terminal -// punctuation mark (ellipsis) followed by an unambiguous one (.!?), then we -// treat the ellipsis as part of the sentence, and return the index of the first -// unambiguous punctuation mark after it. Example: -// -// He agreed...! -// -// We treat "!" as the first terminal punctuation mark; the ellipsis acts as -// left context. -absl::Status SentenceFragmenter::GetAdjustedFirstTerminalPuncIndex( - const FragmentBoundaryMatch &match, int *result) const { - // Get terminal punctuation span. - int i1 = match.first_terminal_punc_index(); - if (i1 < 0) { - *result = i1; - return absl::OkStatus(); - } - int i2 = match.first_close_punc_index(); - - for (int i = i2; i > i1; --i) { - const auto &token = document_->tokens()[i - 1]; - bool is_ellipsis = false; - TF_RETURN_IF_ERROR(util_->IsEllipsis(token.word(), &is_ellipsis)); - if (is_ellipsis || TokenHasProperty(Token::EMOTICON, token)) { - if (i == i2) { - // Ellipsis is last terminal punctuation mark. No adjustment. - *result = i1; - return absl::OkStatus(); - } else { - // Ellipsis is not the last terminal punctuation mark. Return the index - // of the terminal punctuation mark after it. - *result = i; // current token = i - 1 - return absl::OkStatus(); - } - } - } - - // No ellipsis. - *result = i1; - return absl::OkStatus(); -} - -// Example of an an "unattachable" terminal punctuation mark: -// -// He agreed!? -// -// The "?" 
is "unattachable" in that it can't be part of the word "agreed" -// because of the intervening "!", and therefore strongly suggests this is a -// true sentence boundary. The terminal punctuation mark must be unambiguous -// (.!?), as ambiguous ones (ellipsis/emoticon) do not necessarily imply a -// sentence boundary. -absl::Status SentenceFragmenter::HasUnattachableTerminalPunc( - const FragmentBoundaryMatch &match, bool *result) const { - *result = false; - // Get terminal punctuation span. - int i1 = match.first_terminal_punc_index(); - if (i1 < 0) { - *result = false; - return absl::OkStatus(); - } - int i2 = match.first_close_punc_index(); - - // Iterate over the second and later punctuation marks. - for (int i = i1 + 1; i < i2; ++i) { - const auto &token = document_->tokens()[i]; - bool is_punctuation = false; - TF_RETURN_IF_ERROR(util_->IsPunctuationWord(token.word(), &is_punctuation)); - bool is_ellipsis = false; - TF_RETURN_IF_ERROR(util_->IsEllipsis(token.word(), &is_ellipsis)); - if (is_punctuation && !is_ellipsis && - !TokenHasProperty(Token::EMOTICON, token)) { - // Found an unattachable, unambiguous terminal punctuation mark. - *result = true; - return absl::OkStatus(); - } - } - - *result = false; - return absl::OkStatus(); -} - -absl::Status SentenceFragmenter::HasCloseParen( - const FragmentBoundaryMatch &match, bool *result) const { - *result = false; - // Get close punctuation span. 
- int i1 = match.first_close_punc_index(); - if (i1 < 0) { - *result = false; - return absl::OkStatus(); - } - int i2 = match.limit_index(); - - for (int i = i1; i < i2; ++i) { - const auto &token = document_->tokens()[i]; - bool is_close_paren = false; - TF_RETURN_IF_ERROR(util_->IsCloseParen(token.word(), &is_close_paren)); - if (is_close_paren) { - *result = true; - return absl::OkStatus(); - } - } - *result = false; - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_fragmenter.h b/tensorflow_text/core/kernels/sentence_fragmenter.h index 25f1038e1..c30f8ad1a 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter.h +++ b/tensorflow_text/core/kernels/sentence_fragmenter.h @@ -12,213 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// A class to split up a document into sentence fragments. A sentence -// fragment is a token sequence whose end is potentially an end-of-sentence. -// -// Example: -// -// Document text: -// John said, "I.B.M. went up 5 points today." -// -// SentenceFragments: -// (1) John said, "I.B.M. -// (2) went up 5 points today." -// -// Fragment boundaries are induced by punctuation and paragraph breaks. - -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ - -#include -#include - -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow_text/core/kernels/sentence_breaking_utils.h" - -namespace tensorflow { -namespace text { - -class Token { - public: - enum BreakLevel { - NO_BREAK = 0, // No separation between tokens. - SPACE_BREAK = 1, // Tokens separated by space. - LINE_BREAK = 2, // Tokens separated by line break. - SENTENCE_BREAK = 3, // Tokens separated by sentence break. - PARAGRAPH_BREAK = 4, // Tokens separated by paragraph break. - SECTION_BREAK = 10, // Tokens separated by section break. 
- CHAPTER_BREAK = 20, // Tokens separated by chapter break. - }; - - // Bitmask for properties of the token text. - enum TextProperty { - NONE = 0x00, - - // Token is ill-formed if: - // - // All tokens in a paragraph are marked as ill-formed if it has too few - // non-punctuation tokens in a paragraph (currently, a heading must have - // at least 2 tokens, and a non-heading must have at least 8). - // - // All tokens in a paragraph are marked as ill-formed if it lacks terminal - // sentence ending punctuation(e.g.: . ! ? …) or an emoticon (e.g.: ':)', - // ':D'). - // Exception: If a paragraph ends in an introductory punctuation - // character (','':' ';'), we say that it is an introductory paragraph. - // If it is followed by a "simple" HTML list (one whose list items have - // no substructure, such as embedded tables), then we keep both the - // introductory paragraph and the entire list. If not, we keep the - // introductory paragraph if it is followed by a well-formed paragraph. - // - // All tokens in a paragraph are marked as ill-formed if it contains the - // copyright sign (C in a circle) as this usually indicates a copyright - // notice, and is therefore effectively boilerplate. - ILL_FORMED = 0x01, - - // Indicates that the token is a part of the page title ( tag) or - // a heading (<hN> tag). - TITLE = 0x40, - HEADING = 0x02, - - // Text style. Determined from HTML tags only (<b>, etc), not from CSS. - BOLD = 0x04, - ITALIC = 0x08, - UNDERLINED = 0x10, - - // Indicates that the token is a part of a list. Currently set only for - // "simple" HTML lists (have no embedded paragraph boundaries) that are - // preceded by an introductory paragraph (ends in colon or a few other - // characters). - LIST = 0x20, - - // Token is an emoticon. - EMOTICON = 0x80, - - // Token was identified by Lexer as an acronym. Lexer identifies period-, - // hyphen-, and space-separated acronyms: "U.S.", "U-S", and "U S". 
- // Lexer normalizes all three to "US", but the token.word field - // normalizes only space-separated acronyms. - ACRONYM = 0x100, - - // Indicates that the token (or part of the token) is a covered by at - // least one hyperlink. More information of the hyperlink is stored in the - // first token covered by the hyperlink. - HYPERLINK = 0x200, - }; - - Token(const tstring &word, uint32 start, uint32 end, BreakLevel break_level, - TextProperty text_properties) - : word_(word), - start_(start), - end_(end), - break_level_(break_level), - text_properties_(text_properties) {} - - const tstring &word() const { return word_; } - const uint32 start() const { return start_; } - const uint32 end() const { return end_; } - const BreakLevel break_level() const { return break_level_; } - const TextProperty text_properties() const { return text_properties_; } - - private: - const tstring &word_; - uint32 start_; - uint32 end_; - BreakLevel break_level_; - TextProperty text_properties_; -}; - -class Document { - public: - // Does NOT take ownership of 'tokens'. - Document(std::vector<Token> *tokens) : tokens_(tokens) {} - - void AddToken(const tstring &word, uint32 start, uint32 end, - Token::BreakLevel break_level, - Token::TextProperty text_properties) { - tokens_->emplace_back(word, start, end, break_level, text_properties); - } - - const std::vector<Token> &tokens() const { return *tokens_; } - - private: - // not owned - std::vector<Token> *tokens_; -}; - -struct SentenceFragment { - int start; - int limit; - - enum Property { - TERMINAL_PUNC = 0x0001, // ends with terminal punctuation - MULTIPLE_TERMINAL_PUNC = 0x0002, // e.g.: She said what?! - HAS_CLOSE_PAREN = 0x0004, // e.g.: Mushrooms (they're fungi!!) - HAS_SENTENTIAL_CLOSE_PAREN = 0x0008, // e.g.: (Mushrooms are fungi!) - }; - // A mask of the above listed properties. - uint32 properties = 0; - int terminal_punc_token = -1; -}; - -// Utility class for splitting documents into a list of sentence fragments. 
-class SentenceFragmenter { - public: - // Constructs a fragmenter to process a specific part of a document. - SentenceFragmenter(const Document *document, UnicodeUtil *util) - : document_(document), util_(util) {} - - // Finds sentence fragments in the [start_, limit_) range of the associated - // document. - absl::Status FindFragments(std::vector<SentenceFragment> *result); - - private: - // State for matching a fragment-boundary regexp against a token sequence. - // The regexp is: terminal_punc+ close_punc*. - class FragmentBoundaryMatch; - - // Matches a fragment-boundary regexp against the tokens starting at - // 'i_start'. Returns the longest match found; will be non-empty as long as - // 'i_start' was not already at the end of the associated token range. - absl::Status FindNextFragmentBoundary(int i_start, - FragmentBoundaryMatch *result) const; - - // Updates 'latest_open_paren_is_sentential_' for the tokens in the given - // fragment. - absl::Status UpdateLatestOpenParenForFragment(int i_start, int i_end); - - // Populates a sentence fragment with the tokens from 'i_start' to the end - // of the given FragmentBoundaryMatch. - absl::Status FillInFragmentFields(int i_start, - const FragmentBoundaryMatch &match, - SentenceFragment *fragment) const; - - // Returns the adjusted first terminal punctuation index in a - // FragmentBoundaryMatch. - absl::Status GetAdjustedFirstTerminalPuncIndex( - const FragmentBoundaryMatch &match, int *result) const; - - // Returns true iff a FragmentBoundaryMatch has an "unattachable" terminal - // punctuation mark. - absl::Status HasUnattachableTerminalPunc(const FragmentBoundaryMatch &match, - bool *result) const; - - // Returns true iff a FragmentBoundaryMatch has a close paren in its closing - // punctuation. - absl::Status HasCloseParen(const FragmentBoundaryMatch &match, - bool *result) const; - - // Whether the latest open paren seen so far appears to be sentence-initial. 
- // See UpdateLatestOpenParenForFragment() in the .cc file for details. - bool latest_open_paren_is_sentential_ = false; - - const Document *document_ = nullptr; // not owned - UnicodeUtil *util_ = nullptr; // not owned - - // TODO(thuang513): DISALLOW_COPY_AND_ASSIGN(SentenceFragmenter); -}; +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/sentence_fragmenter.h" -#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_H_ diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2.cc b/tensorflow_text/core/kernels/sentence_fragmenter_v2.cc deleted file mode 100644 index d917106c2..000000000 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2.cc +++ /dev/null @@ -1,706 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2.h" - -#include <string> - -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -namespace text { - -void ConsumeOneUChar(const absl::string_view& input, UChar32* result, - int* offset) { - const char* source = input.data(); - - int input_length = input.length(); - U8_NEXT_OR_FFFD(source, *offset, input_length, *result); -} - -bool IsTerminalPunc(const absl::string_view& input, int* offset) { - *offset = 0; - bool is_ellipsis = IsEllipsis(input, offset); - if (is_ellipsis) return true; - - *offset = 0; - UChar32 char_value; - ConsumeOneUChar(input, &char_value, offset); - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. - switch (char_value) { - case 0x055C: // Armenian exclamation mark - case 0x055E: // Armenian question mark - case 0x17d4: // Khmer sign khan - case 0x037E: // Greek question mark - case 0x2026: // ellipsis - return true; - } - - USentenceBreak sb_property = static_cast<USentenceBreak>( - u_getIntPropertyValue(char_value, UCHAR_SENTENCE_BREAK)); - return sb_property == U_SB_ATERM || sb_property == U_SB_STERM; -} - -bool IsClosePunc(const absl::string_view& input, int* offset) { - *offset = 0; - - if (absl::StartsWith(input, "''")) { - *offset += absl::string_view("''").length(); - return true; - } - - UChar32 char_value; - ConsumeOneUChar(input, &char_value, offset); - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. 
- switch (char_value) { - case '>': - case ']': - case '`': - case 64831: // Ornate right parenthesis - case 65282: // fullwidth quotation mark - case 65287: // fullwidth apostrophe - return true; - } - - ULineBreak lb_property = static_cast<ULineBreak>( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - - return lb_property == U_LB_CLOSE_PUNCTUATION || - lb_property == U_LB_CLOSE_PARENTHESIS || lb_property == U_LB_QUOTATION; -} - -bool IsOpenParen(const absl::string_view& input) { - int offset = 0; - UChar32 char_value; - ConsumeOneUChar(input, &char_value, &offset); - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. - switch (char_value) { - case '<': - case 64830: // Ornate left parenthesis - return true; - } - - ULineBreak lb_property = static_cast<ULineBreak>( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - return lb_property == U_LB_OPEN_PUNCTUATION; -} - -bool IsCloseParen(const absl::string_view& input) { - int offset = 0; - - UChar32 char_value; - ConsumeOneUChar(input, &char_value, &offset); - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. - switch (char_value) { - case '>': - case 64831: // Ornate right parenthesis - return true; - } - - ULineBreak lb_property = static_cast<ULineBreak>( - u_getIntPropertyValue(char_value, UCHAR_LINE_BREAK)); - return lb_property == U_LB_CLOSE_PUNCTUATION || - lb_property == U_LB_CLOSE_PARENTHESIS; -} - -bool IsPunctuationWord(const absl::string_view& input) { - int offset = 0; - UChar32 char_value; - ConsumeOneUChar(input, &char_value, &offset); - - // These are unicode characters that should be considered in this category but - // are not covered by any of the ICU properties. 
- switch (char_value) { - case '`': - case '<': - case '>': - case '~': - case 5741: - return true; - } - - return u_ispunct(char_value) || u_hasBinaryProperty(char_value, UCHAR_DASH) || - u_hasBinaryProperty(char_value, UCHAR_HYPHEN); -} - -bool IsEllipsis(const absl::string_view& input, int* offset) { - *offset = 0; - if (absl::StartsWith(input, "...")) { - *offset += absl::string_view("...").length(); - return true; - } - - const UChar32 kEllipsisCharValue = 0x2026; - UChar32 char_value; - ConsumeOneUChar(input, &char_value, offset); - - return char_value == kEllipsisCharValue; -} - -inline bool IsAcronymComponent(const absl::string_view& input, int index) { - return (input.data()[index] >= 'A' && input.data()[index] <= 'Z') && - input.data()[index + 1] == '.'; -} - -bool IsPeriodSeparatedAcronym(const absl::string_view& input, int* offset) { - bool result = false; - - for (int i = 0; i < static_cast<int>(input.length()) - 1; i += 2) { - if (IsAcronymComponent(input, i)) { - *offset = i + 2; - if (*offset > 2) { - result = true; - } - } else { - break; - } - } - return result; -} - -bool IsEmoticon(const absl::string_view& input, int* offset) { - *offset = 0; - static std::vector<std::string> emoticon_list = {":(:)", - ":)", - ":(", - ":o)", - ":]", - ":3", - ":>", - "=]", - "=)", - ":}", - ":^)", - ":-D", - ":-)))))", - ":-))))", - ":-)))", - ":-))", - ":-)", - ">:[", - ":-(", - ":(", - ":-c", - ":c", - ":-<", - ":<", - ":-[", - ":[", - ":{", - ";(", - ":-||", - ":@", - ">:(", - ":'-(", - ":'(", - ":'-)", - ":')", - "D:<", - ">:O", - ":-O", - ":-o", - ":*", - ":-*", - ":^*", - ";-)", - ";)", - "*-)", - "*)", - ";-]", - ";]", - ";^)", - ":-,", - ">:P", - ":-P", - ":p", - "=p", - ":-p", - "=p", - ":P", - "=P", - ";p", - ";-p", - ";P", - ";-P", - ">:\\", - ">:/", - ":-/", - ":-.", - ":/", - ":\\", - "=/", - "=\\", - ":|", - ":-|", - ":$", - ":-#", - ":#", - "O:-)", - "0:-)", - "0:)", - "0;^)", - ">:)", - ">;)", - ">:-)", - "}:-)", - "}:)", - "3:-)", - ">_>^", - 
"^<_<", - "|;-)", - "|-O", - ":-J", - ":-&", - ":&", - "#-)", - "<3", - "8-)", - "^_^", - ":D", - ":-D", - "=D", - "^_^;;", - "O=)", - "}=)", - "B)", - "B-)", - "=|", - "-_-", - "o_o;", - "u_u", - ":-\\", - ":s", - ":S", - ":-s", - ":-S", - ";*", - ";-*" - "=(", - ">.<", - ">:-(", - ">:(", - ">=(", - ";_;", - "T_T", - "='(", - ">_<", - "D:", - ":o", - ":-o", - "=o", - "o.o", - ":O", - ":-O", - "=O", - "O.O", - "x_x", - "X-(", - "X(", - "X-o", - "X-O", - ":X)", - "(=^.^=)", - "(=^..^=)", - "=^_^=", - "-<@%", - ":(|)", - "(]:{", - "<\\3", - "~@~", - "8'(", - "XD", - "DX"}; - - for (int i = 0; i < static_cast<int>(emoticon_list.size()); ++i) { - if (absl::StartsWith(input, emoticon_list[i])) { - *offset = emoticon_list[i].length(); - return true; - } - } - return false; -} - -// Returns true iff the punctuation input can appear after a space in a -// sentence-terminal punctuation sequence. -bool SpaceAllowedBeforeChar(const absl::string_view& input) { - int offset = 0; - bool is_terminal_punc = IsTerminalPunc(input, &offset); - bool is_close_paren = IsCloseParen(input); - bool is_emoticon = IsEmoticon(input, &offset); - return is_terminal_punc || is_close_paren || is_emoticon; -} - -bool IsWhiteSpace(const absl::string_view& input) { - int offset = 0; - - if (absl::StartsWith(input, " ")) { - return true; - } else if (absl::StartsWith(input, "\n")) { - return true; - } else if (absl::StartsWith(input, " ")) { - return true; - } - - UChar32 char_value; - ConsumeOneUChar(input, &char_value, &offset); - - return u_isUWhiteSpace(char_value); -} - -// Follows the state transition for the slice at the given index. Returns true -// for success, or false if there was no valid transition. -bool FragmentBoundaryMatch::Advance(int index, absl::string_view slice) { - int temp_offset; - // By defualt offset is the next character. 
- int offset = 1; - bool no_transition = false; - bool is_terminal_punc = IsTerminalPunc(slice, &temp_offset); - if (is_terminal_punc) { - offset = temp_offset; - } - - bool is_ellipsis = IsEllipsis(slice, &temp_offset); - if (is_ellipsis) { - offset = temp_offset; - } - bool is_close_punc = IsClosePunc(slice, &temp_offset); - if (is_close_punc) { - offset = temp_offset; - } - bool is_acronym = IsPeriodSeparatedAcronym(slice, &temp_offset); - if (is_acronym) { - is_terminal_punc = false; - offset = temp_offset; - } - bool is_emoticon = IsEmoticon(slice, &temp_offset); - if (is_emoticon) { - is_terminal_punc = false; - offset = temp_offset; - } - - switch (state_) { - case INITIAL_STATE: - if (is_terminal_punc || is_acronym || is_emoticon) { - first_terminal_punc_index_ = index; - state_ = COLLECTING_TERMINAL_PUNC; - } - break; - case COLLECTING_TERMINAL_PUNC: - if (is_terminal_punc || is_emoticon) { - // Stay in COLLECTING_TERMINAL_PUNC state. - } else if (is_close_punc) { - first_close_punc_index_ = index; - state_ = COLLECTING_CLOSE_PUNC; - } else { - no_transition = true; - } - break; - case COLLECTING_CLOSE_PUNC: - if (is_close_punc || is_ellipsis || is_emoticon) { - // Stay in COLLECTING_CLOSE_PUNC state. We effectively ignore - // emoticons and ellipses and continue to accept closing punctuation - // after them. - } else { - no_transition = true; - } - break; - } - - if (no_transition) { - return false; - } else { - limit_index_ = index + offset; - if (state_ == COLLECTING_TERMINAL_PUNC) { - // We've gotten terminal punctuation, but no close punctuation yet. - first_close_punc_index_ = limit_index_; - } - return true; - } -} - -// Sets a property of a sentence fragment. 
-void SetFragmentProperty(SentenceFragment::Property property, - SentenceFragment* fragment) { - fragment->properties = fragment->properties | property; -} - -absl::Status SentenceFragmenterV2::FindFragments( - std::vector<SentenceFragment>* result) { - // Partition document into sentence fragments. - for (int i_start = 0; i_start < static_cast<int>(document_.size());) { - bool is_white_space = IsWhiteSpace(document_.substr(i_start)); - if (is_white_space) { - ++i_start; - continue; - } - - SentenceFragment fragment; - - // Match regexp for fragment boundary. - FragmentBoundaryMatch match = FindNextFragmentBoundary(i_start); - - // Update 'latest_open_paren_is_sentential_' for this fragment. - UpdateLatestOpenParenForFragment(i_start, match.limit_index()); - - // Add a new sentence fragment up to this boundary. - FillInFragmentFields(i_start, match, &fragment); - - result->push_back(std::move(fragment)); - i_start = match.limit_index(); - } - return absl::OkStatus(); -} - -// This method is essentially a control layer on top of a simple state machine -// that matches an end-of-fragment regexp. This method finds the next slice of -// text to feed to the state machine, and handles embedded whitespace. The main -// complexity is that a space may delimit end-of-match, or be embedded in the -// termination sequence. When we encounter a space, we record the match found so -// far, but also continue matching. We return the longer match if it succeeds, -// else fall back to the earlier one. Note that the lookahead can incur at most -// 2n cost. -// -// E.g., suppose we're given: x? !!!y. We encounter the space after "x?" and -// have to look ahead all the way to "y" before realizing that the longer match -// fails. We put a fragment boundary after "x?", and next time around, we again -// scan "!!!" looking for a fragment boundary. Since we failed to find one last -// time, we'll fail again this time and therefore continue past "y" to find the -// next boundary. 
We will not try to scan "!!!" a third time. - -FragmentBoundaryMatch SentenceFragmenterV2::FindNextFragmentBoundary( - int doc_index) const { - FragmentBoundaryMatch current_match; - FragmentBoundaryMatch previous_match; - - for (int i = doc_index; i < static_cast<int>(document_.size()); ++i) { - absl::string_view slice = document_.substr(i); - if (current_match.GotTerminalPunc() && i > doc_index) { - // Got terminal punctuation and a space delimiter, so match is valid. - bool space_allowed_before_char = SpaceAllowedBeforeChar(slice); - if (space_allowed_before_char) { - // Remember this match. Try to extend it. - previous_match = current_match; - } else { - // Stop here. We're not allowed to extend the match in this case. - break; - } - } - bool got_transition = current_match.Advance(i, slice); - if (!got_transition) { - if (previous_match.GotTerminalPunc()) { - // Extension failed. Return previous match. - return previous_match; - } else { - // Start matching again from scratch. - current_match.Reset(); - - // Reprocess current character since it might be terminal punctuation. - // No infinite loop, because can't be "no transition" from - // INITIAL_STATE. - --i; - } - } else { - i = current_match.limit_index() - 1; - } - } - return current_match; -} - -// Keep track of whether the latest open parenthesis seen so far appears to be -// sentence-initial. This is useful because if it is *non-sentence-initial*, -// then any terminal punctuation before the corresponding close paren is -// probably not a sentence boundary. Example: -// -// Mushrooms (they're fungi!!) are delicious. -// (Mushrooms are fungi!!) -// -// In the first case, the open paren is non-sentence-initial, and therefore -// the "!!)" is not a sentence boundary. In the second case, the open paren *is* -// sentence-initial, and so the "!!)" is a sentence boundary. 
-// -// Of course, we don't know true sentence boundaries, so we make the -// approximation that an open paren is sentence-initial iff it is -// fragment-initial. This will be wrong if the open paren occurs after terminal -// punctuation that turns out not to be a sentence boundary, e.g., -// "Yahoo! (known for search, etc.) blah", but this is not expected to happen -// often. -void SentenceFragmenterV2::UpdateLatestOpenParenForFragment(int i_start, - int i_end) { - for (int i = i_end; i > i_start; --i) { - absl::string_view slice = document_.substr(i); - if (slice.length() > 0 && IsOpenParen(slice)) { - // Make the approximation that this open paren is sentence-initial iff it - // is fragment-initial. - latest_open_paren_is_sentential_ = (i == i_start); - break; - } - } -} - -void SentenceFragmenterV2::FillInFragmentFields( - int i_start, const FragmentBoundaryMatch& match, - SentenceFragment* fragment) const { - // Set the fragment's boundaries. - fragment->start = i_start; - fragment->limit = match.limit_index(); - - // Set the fragment's properties. - if (match.GotTerminalPunc()) { - // TERMINAL_PUNC. - SetFragmentProperty(SentenceFragment::TERMINAL_PUNC, fragment); - int terminal_punc_index = GetAdjustedFirstTerminalPuncIndex(match); - - bool has_unattachable_terminal_punc = HasUnattachableTerminalPunc(match); - bool has_close_paren = HasCloseParen(match); - - fragment->terminal_punc_token = terminal_punc_index; - // MULTIPLE_TERMINAL_PUNC. - if (has_unattachable_terminal_punc) { - SetFragmentProperty(SentenceFragment::MULTIPLE_TERMINAL_PUNC, fragment); - } - - // HAS_CLOSE_PAREN & HAS_SENTENTIAL_CLOSE_PAREN. - if (has_close_paren) { - SetFragmentProperty(SentenceFragment::HAS_CLOSE_PAREN, fragment); - - if (latest_open_paren_is_sentential_) { - SetFragmentProperty(SentenceFragment::HAS_SENTENTIAL_CLOSE_PAREN, - fragment); - } - } - } -} - -// The standard first terminal punctuation index is just -// match.first_terminal_punc_index(). 
But if there is an ambiguous terminal -// punctuation mark (ellipsis) followed by an unambiguous one (.!?), then we -// treat the ellipsis as part of the sentence, and return the index of the first -// unambiguous punctuation mark after it. Example: -// -// He agreed...! -// -// We treat "!" as the first terminal punctuation mark; the ellipsis acts as -// left context. -int SentenceFragmenterV2::GetAdjustedFirstTerminalPuncIndex( - const FragmentBoundaryMatch& match) const { - // Get terminal punctuation span. - int i1 = match.first_terminal_punc_index(); - if (i1 < 0) { - return i1; - } - int i2 = match.first_close_punc_index(); - - for (int i = i2; i > i1; --i) { - absl::string_view slice = document_.substr(i); - int temp_offset = 0; - bool is_ellipsis = IsEllipsis(slice, &temp_offset); - bool is_emoticon = IsEmoticon(slice, &temp_offset); - if (is_ellipsis || is_emoticon) { - if (i == i2) { - // Ellipsis is last terminal punctuation mark. No adjustment. - return i1; - } else { - // Ellipsis is not the last terminal punctuation mark. Return the index - // of the terminal punctuation mark after it. - return i; // current character = i - 1 - } - } - } - // No ellipsis. - return i1; -} - -// Example of an an "unattachable" terminal punctuation mark: -// -// He agreed!? -// -// The "?" is "unattachable" in that it can't be part of the word "agreed" -// because of the intervening "!", and therefore strongly suggests this is a -// true sentence boundary. The terminal punctuation mark must be unambiguous -// (.!?), as ambiguous ones (ellipsis/emoticon) do not necessarily imply a -// sentence boundary. -bool SentenceFragmenterV2::HasUnattachableTerminalPunc( - const FragmentBoundaryMatch& match) const { - // Get terminal punctuation span. 
- int i1 = match.first_terminal_punc_index(); - if (i1 < 0) { - return false; - } - // Check where second and later punctuation marks start - absl::string_view start_slice = document_.substr(i1); - int temp_offset = 0; - bool is_ellipsis = IsEllipsis(start_slice, &temp_offset); - if (is_ellipsis) { - i1 += temp_offset - 1; - } - bool is_emoticon = IsEmoticon(start_slice, &temp_offset); - if (is_emoticon) { - i1 += temp_offset - 1; - } - - int i2 = match.first_close_punc_index(); - - // Iterate over the second and later punctuation marks. - for (int i = i1 + 1; i < i2; ++i) { - absl::string_view slice = document_.substr(i); - bool is_punctuation = IsPunctuationWord(slice); - is_ellipsis = IsEllipsis(slice, &temp_offset); - if (is_ellipsis) { - i += temp_offset - 1; - } - is_emoticon = IsEmoticon(slice, &temp_offset); - if (is_emoticon) { - i += temp_offset - 1; - } - if (is_punctuation && !is_ellipsis && !is_emoticon) { - // Found an unattachable, unambiguous terminal punctuation mark. - return true; - } - } - return false; -} - -bool SentenceFragmenterV2::HasCloseParen( - const FragmentBoundaryMatch& match) const { - // Get close punctuation span. - int i1 = match.first_close_punc_index(); - if (i1 < 0) { - return false; - } - int i2 = match.limit_index(); - - for (int i = i1; i < i2; ++i) { - absl::string_view slice = document_.substr(i); - if (IsCloseParen(slice)) { - return true; - } - } - return false; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2.h b/tensorflow_text/core/kernels/sentence_fragmenter_v2.h index 6c06867eb..fec2ea0b3 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2.h +++ b/tensorflow_text/core/kernels/sentence_fragmenter_v2.h @@ -12,189 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Updated version of sentence fragmenter and util functions to split up a -// document into sentence fragments. 
A sentence fragment is a string whose end -// is potentially an end-of-sentence. The original version of -// sentence_fragmenter operates on tokens and defines the start and end of -// fragments using token indices, while sentence_fragmenter_v2 operates on a -// string_view sliding window of the text and defines the start and end of a -// fragment based on the character offset. -// -// Example: -// -// Document text: -// John said, "I.B.M. went up 5 points today." -// -// SentenceFragments: -// (1) John said, "I.B.M. -// (2) went up 5 points today." -// -// Fragment boundaries are induced by punctuation and paragraph breaks. - -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_H_ - -#include <vector> - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -namespace text { - -// A class of utils for identifying certain classes and properties of unicode -// characters. These utils are included in the header for use in tests. - -// Returns true iff a string is terminal punctuation. -bool IsTerminalPunc(const absl::string_view& input, int* offset); - -// Returns true iff a string is close punctuation (close quote or close -// paren). -bool IsClosePunc(const absl::string_view& input, int* offset); - -// Returns true iff a string is an open paren. -bool IsOpenParen(const absl::string_view& input); - -// Returns true iff a string is a close paren. -bool IsCloseParen(const absl::string_view& input); - -// Returns true iff a word is made of punctuation characters only. -bool IsPunctuationWord(const absl::string_view& input); - -// Returns true iff a string is an ellipsis ("..."). -bool IsEllipsis(const absl::string_view& input, int* offset); - -// Returns true iff a string is a period separated acronym (ex: "A.B.C."). 
-bool IsPeriodSeparatedAcronym(const absl::string_view& input, int* offset); - -// Returns true iff a string is an emoticon (ex: ":-)"). -bool IsEmoticon(const absl::string_view& input, int* offset); - -bool SpaceAllowedBeforeChar(const absl::string_view& input); - -void ConsumeOneUChar(const absl::string_view& input, UChar32* result, - int* offset); - -// Returns true iff a string is white space. -bool IsWhiteSpace(const absl::string_view& input); - -class FragmentBoundaryMatch { - public: - FragmentBoundaryMatch() {} - - // Goes to initial state. - void Reset() { - state_ = INITIAL_STATE; - first_terminal_punc_index_ = -1; - first_close_punc_index_ = -1; - limit_index_ = -1; - } - - // Follows the state transition for the slice at - // the given index. Returns true for success, or - // false if there was no valid transition. - bool Advance(int index, absl::string_view slice); - - // Returns true iff we have matched at least one terminal punctuation - // character. - bool GotTerminalPunc() const { return first_terminal_punc_index_ >= 0; } - - // Field accessors. - int first_terminal_punc_index() const { return first_terminal_punc_index_; } - int first_close_punc_index() const { return first_close_punc_index_; } - int limit_index() const { return limit_index_; } - - // Match state. - enum MatchState { - INITIAL_STATE = 0, - COLLECTING_TERMINAL_PUNC, - COLLECTING_CLOSE_PUNC - }; - - MatchState state() const { return state_; } - - private: - MatchState state_ = INITIAL_STATE; - - // First terminal punctuation mark matched; may be an acronym. - // -1 for not found. - int first_terminal_punc_index_ = -1; - - // First closing punctuation mark matched. -1 for not found. - int first_close_punc_index_ = -1; - - // First character after the terminal sequence. 
- int limit_index_ = -1; -}; - -struct SentenceFragment { - int start; - int limit; - - enum Property { - TERMINAL_PUNC = 0x0001, // ends with terminal punctuation - MULTIPLE_TERMINAL_PUNC = 0x0002, // e.g.: She said what?! - HAS_CLOSE_PAREN = 0x0004, // e.g.: Mushrooms (they're fungi!!) - HAS_SENTENTIAL_CLOSE_PAREN = 0x0008, // e.g.: (Mushrooms are fungi!) - }; - // A mask of the above listed properties. - uint32 properties = 0; - int terminal_punc_token = -1; -}; - -// Utility class for splitting documents into a list of sentence fragments. -class SentenceFragmenterV2 { - public: - // Constructs a fragmenter to process a specific part of a document. - SentenceFragmenterV2(absl::string_view document) : document_(document) {} - - // Finds sentence fragments in the [start_, limit_) range of the associated - // document. - absl::Status FindFragments(std::vector<SentenceFragment>* result); - - private: - // State for matching a fragment-boundary regexp against a character sequence. - // The regexp is: terminal_punc+ close_punc*. - - // Matches a fragment-boundary regexp against a slice of the document starting - // at 'doc_index'. Returns the longest match found; will be non-empty as long - // as 'doc_index' was not already at the end of the associated document. - FragmentBoundaryMatch FindNextFragmentBoundary(int doc_index) const; - - // Updates 'latest_open_paren_is_sentential_' for the given - // fragment. - void UpdateLatestOpenParenForFragment(int i_start, int i_end); - - // Populates a sentence fragment with the text from 'i_start' to the end - // of the given FragmentBoundaryMatch. - void FillInFragmentFields(int i_start, const FragmentBoundaryMatch& match, - SentenceFragment* fragment) const; - - // Returns the adjusted first terminal punctuation index in a - // FragmentBoundaryMatch. 
- int GetAdjustedFirstTerminalPuncIndex( - const FragmentBoundaryMatch& match) const; - - // Returns true iff a FragmentBoundaryMatch has an "unattachable" terminal - // punctuation mark. - bool HasUnattachableTerminalPunc(const FragmentBoundaryMatch& match) const; - - // Returns true iff a FragmentBoundaryMatch has a close paren in its closing - // punctuation. - bool HasCloseParen(const FragmentBoundaryMatch& match) const; - - // Whether the latest open paren seen so far appears to be sentence-initial. - // See UpdateLatestOpenParenForFragment() in the .cc file for details. - bool latest_open_paren_is_sentential_ = false; - - absl::string_view document_ = {}; // not owned - - // TODO(thuang513): DISALLOW_COPY_AND_ASSIGN(SentenceFragmenter); -}; +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_H_ -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/sentence_fragmenter_v2.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_H_ diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.h b/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.h index d36c7e9c5..fd0e910a2 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.h +++ b/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel.h @@ -15,19 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h" - -namespace tensorflow { -namespace text { - -class SentenceFragmenterV2OpKernel - : public tflite::shim::TfOpKernel<SentenceFragmenterV2Op> { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include 
"tensorflow/core/kernels/text/sentence_fragmenter_v2_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h b/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h index ea7be5862..954d03ac8 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h +++ b/tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h @@ -15,150 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_TEMPLATE_H_ - -#include <iostream> -#include <vector> - -#include "absl/status/status.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2.h" - -namespace tensorflow { -namespace text { - -template <tflite::shim::Runtime Rt> -class SentenceFragmenterV2Op - : public tflite::shim::OpKernelShim<SentenceFragmenterV2Op, Rt> { - private: - enum Inputs { - kInputValues = 0 - }; - enum Outputs { - kFragmentStart = 0, - kFragmentEnd, - kFragmentProperties, - kTerminalPuncToken, - kOutputRowLengths - }; - - using typename tflite::shim::OpKernelShim<SentenceFragmenterV2Op, - Rt>::InitContext; - using typename tflite::shim::OpKernelShim<SentenceFragmenterV2Op, - Rt>::InvokeContext; - using typename tflite::shim::OpKernelShim<SentenceFragmenterV2Op, - Rt>::ShapeInferenceContext; - - public: - SentenceFragmenterV2Op() = default; - static constexpr char kOpName[] = "SentenceFragmentsV2"; - static constexpr char kDoc[] = R"doc( - Splits a string into sentence fragments - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: 
https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template <tflite::shim::Runtime Rt> -std::vector<std::string> SentenceFragmenterV2Op<Rt>::Inputs() { - return {"doc: string"}; -} - -template <tflite::shim::Runtime Rt> -std::vector<std::string> SentenceFragmenterV2Op<Rt>::Outputs() { - return {"fragment_start: int64", "fragment_end: int64", - "fragment_properties: int64", "terminal_punc_token: int64", - "output_row_lengths: int64"}; -} - -template <tflite::shim::Runtime Rt> -absl::Status SentenceFragmenterV2Op<Rt>::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - - SH_ASSIGN_OR_RETURN(const Shape& input_values_shape, - c->GetInputShape(kInputValues)); - if (!input_values_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_values_shape.ToString())); - } - - SH_RETURN_IF_ERROR(c->SetOutputShape(kFragmentStart, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kFragmentEnd, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kFragmentProperties, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kTerminalPuncToken, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowLengths, rank_1_shape)); - - return absl::OkStatus(); -} - -template <tflite::shim::Runtime Rt> -absl::Status SentenceFragmenterV2Op<Rt>::Invoke(InvokeContext* context) { - // 
Inputs - SH_ASSIGN_OR_RETURN(const auto input_values, context->GetInput(kInputValues)); - const auto document = input_values->template As<tensorflow::tstring, 1>(); - - // Outputs - std::vector<int64> fragment_start; - std::vector<int64> fragment_end; - std::vector<int64> fragment_properties; - std::vector<int64> terminal_punc_token; - std::vector<int64> output_row_lengths; - - // Iterate through all the documents and find fragments. - for (int i = 0; i < document.Dim(0); ++i) { - // Find fragments. - SentenceFragmenterV2 fragmenter(document(i)); - std::vector<SentenceFragment> frags; - - SH_RETURN_IF_ERROR(fragmenter.FindFragments(&frags)); - - for (const auto& f : frags) { - fragment_start.push_back(f.start); - fragment_end.push_back(f.limit); - fragment_properties.push_back(f.properties); - terminal_punc_token.push_back(f.terminal_punc_token); - } - output_row_lengths.push_back(frags.size()); - } - - // Allocate output & fill output tensors. - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - fragment_start, kFragmentStart, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - fragment_end, kFragmentEnd, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - fragment_properties, kFragmentProperties, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - terminal_punc_token, kTerminalPuncToken, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - output_row_lengths, kOutputRowLengths, context)); - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/sentence_fragmenter_v2_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_test.cc b/tensorflow_text/core/kernels/sentence_fragmenter_v2_test.cc deleted file mode 100644 index 
87cd49265..000000000 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2_test.cc +++ /dev/null @@ -1,1092 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2.h" - -#include <string> -#include <vector> - -#include <gtest/gtest.h> -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/unistr.h" - -namespace tensorflow { -namespace text { -namespace { - -class SentenceBreakingUtilsParamTest : public ::testing::TestWithParam<UChar> { - protected: - std::string StringFromUnicodeChar(UChar32 input) { - std::string result; - icu::UnicodeString test_unicode_string(input); - test_unicode_string.toUTF8String(result); - return result; - } -}; - -class SentenceBreakingUtilsStringParamTest - : public ::testing::TestWithParam<const char*> {}; - -class IsTerminalPuncParamTest : public SentenceBreakingUtilsParamTest {}; - -class IsTerminalPuncTest : public ::testing::Test {}; - -const UChar is_terminal_punc_test_cases[] = { - 0x055C, // Armenian exclamation mark - 0x055E, // Armenian question mark - 0x0589, // Armenian full stop - 0x061F, // Arabic question mark - 0x06D4, // Arabic full stop - 0x0700, // Syriabc end of paragraph - 0x0701, // Syriac supralinear full stop - 0x0702, // Syriac sublinear full stop - 0x1362, // Ethiopic full stop - 
0x1367, // Ethiopic question mark - 0x1368, // Ethiopic paragraph separator - 0x104A, // Myanmar sign little section - 0x104B, // Myanmar sign section - 0x166E, // Canadian syllabics full stop - 0x17d4, // Khmer sign khan - 0x1803, // Mongolian full stop - 0x1809, // Mongolian Manchu full stop - 0x1944, // Limbu exclamation mark - 0x1945, // Limbu question mark - 0x203C, // double exclamation mark - 0x203D, // interrobang - 0x2047, // double question mark - 0x2048, // question exclamation mark - 0x2049, // exclamation question mark - 0x3002, // ideographic full stop - 0x037E, // Greek question mark - 0xFE52, // small full stop - 0xFE56, // small question mark - 0xFE57, // small exclamation mark - 0xFF01, // fullwidth exclamation mark - 0xFF0E, // fullwidth full stop - 0xFF1F, // fullwidth question mark - 0xFF61, // halfwidth ideographic full stop - 0x2026, // ellipsis - 0x0964, - 0x0965, // Devanagari danda..Devanagari double -}; - -TEST_P(IsTerminalPuncParamTest, IsTerminalPunc) { - std::string test_string = StringFromUnicodeChar(GetParam()); - int offset; - EXPECT_TRUE(IsTerminalPunc(test_string, &offset)); -} - -INSTANTIATE_TEST_SUITE_P(IsTerminalPuncTest, IsTerminalPuncParamTest, - ::testing::ValuesIn(is_terminal_punc_test_cases)); - -TEST_F(IsTerminalPuncTest, IsMultiCharEllipseTerminalPunc) { - std::string test_string = "..."; - int offset; - EXPECT_TRUE(IsTerminalPunc(test_string, &offset)); -} - -TEST_F(IsTerminalPuncTest, TestMultiUnicodeChars) { - std::string test_string = "never gonna let you decode"; - int offset; - EXPECT_FALSE(IsTerminalPunc(test_string, &offset)); -} - -struct ClosePuncOffsetPairs { - const UChar close_punc; - const int offset; -}; - -class SentenceBreakingUtilsClosePuncPairParamTest - : public ::testing::TestWithParam<ClosePuncOffsetPairs> { - protected: - std::string StringFromUnicodeChar(UChar32 input) { - std::string result; - icu::UnicodeString test_unicode_string(input); - test_unicode_string.toUTF8String(result); - return 
result; - } -}; - -class ClosePuncParamTest : public SentenceBreakingUtilsClosePuncPairParamTest { -}; - -const ClosePuncOffsetPairs close_punc_test_cases[] = { - {0x29, 1}, - {0x5D, 1}, - {0x3E, 1}, - {0x7D, 1}, - {0x207E, 3}, // superscript right parenthesis - {0x208E, 3}, // subscript right parenthesis - {0x27E7, 3}, // mathematical right white square bracket - {0x27E9, 3}, // mathematical right angle bracket - {0x27EB, 3}, // mathematical right double angle bracket - {0x2984, 3}, // right white curly bracket - {0x2986, 3}, // right white parenthesis - {0x2988, 3}, // Z notation right image bracket - {0x298A, 3}, // Z notation right binding bracket - {0x298C, 3}, // right square bracket with underbar - {0x298E, 3}, // right square bracket with tick in top corner - {0x2990, 3}, // right square bracket with tick in bottom corner - {0x2992, 3}, // right angle bracket with dot - {0x2994, 3}, // right arc greater-than bracket - {0x2996, 3}, // double right arc less-than bracket - {0x2998, 3}, // right black tortoise shell bracket - {0x29D9, 3}, // right wiggly fence - {0x29DB, 3}, // right double wiggly fence - {0x29FD, 3}, // right-pointing curved angle bracket - {0x3009, 3}, // CJK right angle bracket - {0x300B, 3}, // CJK right double angle bracket - {0x3011, 3}, // CJK right black lenticular bracket - {0x3015, 3}, // CJK right tortoise shell bracket - {0x3017, 3}, // CJK right white lenticular bracket - {0x3019, 3}, // CJK right white tortoise shell bracket - {0x301B, 3}, // CJK right white square bracket - {0xFD3F, 3}, // Ornate right parenthesis - {0xFE5A, 3}, // small right parenthesis - {0xFE5C, 3}, // small right curly bracket - {0xFF09, 3}, // fullwidth right parenthesis - {0xFF3D, 3}, // fullwidth right square bracket - {0xFF5D, 3}, // fullwidth right curly bracket - {0x27, 1}, - {0x60, 1}, - {0x22, 1}, - {0xFF07, 3}, // fullwidth apostrophe - {0xFF02, 3}, // fullwidth quotation mark - {0x2019, 3}, // right single quotation mark (English, others) - 
{0x201D, 3}, // right double quotation mark (English, others) - {0x2018, 3}, // left single quotation mark (Czech, German, Slovak) - {0x201C, 3}, // left double quotation mark (Czech, German, Slovak) - {0x203A, 3}, // single right-pointing angle quotation mark (French, others) - {0x00BB, 2}, // right-pointing double angle quotation mark (French, others) - {0x2039, 3}, // single left-pointing angle quotation mark (Slovenian, - // others) - {0x00AB, 2}, // left-pointing double angle quotation mark (Slovenian, - // others) - {0x300D, 3}, // right corner bracket (East Asian languages) - {0xfe42, 3}, // presentation form for vertical right corner bracket - {0xFF63, 3}, // halfwidth right corner bracket (East Asian languages) - {0x300F, 3}, // right white corner bracket (East Asian languages) - {0xfe44, 3}, // presentation form for vertical right white corner bracket - {0x301F, 3}, // low double prime quotation mark (East Asian languages) - {0x301E, 3} // close double prime (East Asian languages written - // horizontally) -}; - -TEST_P(ClosePuncParamTest, IsClosePunc) { - ClosePuncOffsetPairs test_punc = GetParam(); - std::string test_string = StringFromUnicodeChar(test_punc.close_punc); - int expected_offset = test_punc.offset; - int offset; - EXPECT_TRUE(IsClosePunc(test_string, &offset)); - EXPECT_EQ(offset, expected_offset); -} - -INSTANTIATE_TEST_SUITE_P(IsClosePuncParamTest, ClosePuncParamTest, - ::testing::ValuesIn(close_punc_test_cases)); - -class OpenParenParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar open_paren_test_cases[] = { - '(', '[', '<', '{', - 0x207D, // superscript left parenthesis - 0x208D, // subscript left parenthesis - 0x27E6, // mathematical left white square bracket - 0x27E8, // mathematical left angle bracket - 0x27EA, // mathematical left double angle bracket - 0x2983, // left white curly bracket - 0x2985, // left white parenthesis - 0x2987, // Z notation left image bracket - 0x2989, // Z notation left binding bracket - 
0x298B, // left square bracket with underbar - 0x298D, // left square bracket with tick in top corner - 0x298F, // left square bracket with tick in bottom corner - 0x2991, // left angle bracket with dot - 0x2993, // left arc less-than bracket - 0x2995, // double left arc greater-than bracket - 0x2997, // left black tortoise shell bracket - 0x29D8, // left wiggly fence - 0x29DA, // left double wiggly fence - 0x29FC, // left-pointing curved angle bracket - 0x3008, // CJK left angle bracket - 0x300A, // CJK left double angle bracket - 0x3010, // CJK left black lenticular bracket - 0x3014, // CJK left tortoise shell bracket - 0x3016, // CJK left white lenticular bracket - 0x3018, // CJK left white tortoise shell bracket - 0x301A, // CJK left white square bracket - 0xFD3E, // Ornate left parenthesis - 0xFE59, // small left parenthesis - 0xFE5B, // small left curly bracket - 0xFF08, // fullwidth left parenthesis - 0xFF3B, // fullwidth left square bracket - 0xFF5B, // fullwidth left curly bracket -}; - -TEST_P(OpenParenParamTest, IsOpenParen) { - std::string test_string = StringFromUnicodeChar(GetParam()); - EXPECT_TRUE(IsOpenParen(test_string)); -} - -INSTANTIATE_TEST_SUITE_P(IsOpenParenParamTest, OpenParenParamTest, - ::testing::ValuesIn(open_paren_test_cases)); - -class CloseParenParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar close_paren_test_cases[] = { - ')', ']', '>', '}', - 0x207E, // superscript right parenthesis - 0x208E, // subscript right parenthesis - 0x27E7, // mathematical right white square bracket - 0x27E9, // mathematical right angle bracket - 0x27EB, // mathematical right double angle bracket - 0x2984, // right white curly bracket - 0x2986, // right white parenthesis - 0x2988, // Z notation right image bracket - 0x298A, // Z notation right binding bracket - 0x298C, // right square bracket with underbar - 0x298E, // right square bracket with tick in top corner - 0x2990, // right square bracket with tick in bottom corner - 0x2992, // 
right angle bracket with dot - 0x2994, // right arc greater-than bracket - 0x2996, // double right arc less-than bracket - 0x2998, // right black tortoise shell bracket - 0x29D9, // right wiggly fence - 0x29DB, // right double wiggly fence - 0x29FD, // right-pointing curved angle bracket - 0x3009, // CJK right angle bracket - 0x300B, // CJK right double angle bracket - 0x3011, // CJK right black lenticular bracket - 0x3015, // CJK right tortoise shell bracket - 0x3017, // CJK right white lenticular bracket - 0x3019, // CJK right white tortoise shell bracket - 0x301B, // CJK right white square bracket - 0xFD3F, // Ornate right parenthesis - 0xFE5A, // small right parenthesis - 0xFE5C, // small right curly bracket - 0xFF09, // fullwidth right parenthesis - 0xFF3D, // fullwidth right square bracket - 0xFF5D, // fullwidth right curly bracket -}; - -TEST_P(CloseParenParamTest, IsCloseParen) { - std::string test_string = StringFromUnicodeChar(GetParam()); - EXPECT_TRUE(IsCloseParen(test_string)); -} - -INSTANTIATE_TEST_SUITE_P(IsCloseParenParamTest, CloseParenParamTest, - ::testing::ValuesIn(close_paren_test_cases)); - -class IsPunctuationWordParamTest : public SentenceBreakingUtilsParamTest {}; - -const UChar punc_word_test_cases[] = { - '(', '[', '<', '{', - 0x207D, // superscript left parenthesis - 0x208D, // subscript left parenthesis - 0x27E6, // mathematical left white square bracket - 0x27E8, // mathematical left angle bracket - 0x27EA, // mathematical left double angle bracket - 0x2983, // left white curly bracket - 0x2985, // left white parenthesis - 0x2987, // Z notation left image bracket - 0x2989, // Z notation left binding bracket - 0x298B, // left square bracket with underbar - 0x298D, // left square bracket with tick in top corner - 0x298F, // left square bracket with tick in bottom corner - 0x2991, // left angle bracket with dot - 0x2993, // left arc less-than bracket - 0x2995, // double left arc greater-than bracket - 0x2997, // left black tortoise shell 
bracket - 0x29D8, // left wiggly fence - 0x29DA, // left double wiggly fence - 0x29FC, // left-pointing curved angle bracket - 0x3008, // CJK left angle bracket - 0x300A, // CJK left double angle bracket - 0x3010, // CJK left black lenticular bracket - 0x3014, // CJK left tortoise shell bracket - 0x3016, // CJK left white lenticular bracket - 0x3018, // CJK left white tortoise shell bracket - 0x301A, // CJK left white square bracket - 0xFD3E, // Ornate left parenthesis - 0xFE59, // small left parenthesis - 0xFE5B, // small left curly bracket - 0xFF08, // fullwidth left parenthesis - 0xFF3B, // fullwidth left square bracket - 0xFF5B, // fullwidth left curly bracket - '"', '\'', '`', - 0xFF07, // fullwidth apostrophe - 0xFF02, // fullwidth quotation mark - 0x2018, // left single quotation mark (English, others) - 0x201C, // left double quotation mark (English, others) - 0x201B, // single high-reveresed-9 quotation mark (PropList.txt) - 0x201A, // single low-9 quotation mark (Czech, German, Slovak) - 0x201E, // double low-9 quotation mark (Czech, German, Slovak) - 0x201F, // double high-reversed-9 quotation mark (PropList.txt) - 0x2019, // right single quotation mark (Danish, Finnish, Swedish, Norw.) - 0x201D, // right double quotation mark (Danish, Finnish, Swedish, Norw.) - 0x2039, // single left-pointing angle quotation mark (French, others) - 0x00AB, // left-pointing double angle quotation mark (French, others) - 0x203A, // single right-pointing angle quotation mark (Slovenian, others) - 0x00BB, // right-pointing double angle quotation mark (Slovenian, others) - 0x300C, // left corner bracket (East Asian languages) - 0xFE41, // presentation form for vertical left corner bracket - 0xFF62, // halfwidth left corner bracket (East Asian languages) - 0x300E, // left white corner bracket (East Asian languages) - 0xFE43, // presentation form for vertical left white corner bracket - 0x301D, // reversed double prime quotation mark (East Asian langs, horiz.) 
- ')', ']', '>', '}', - 0x207E, // superscript right parenthesis - 0x208E, // subscript right parenthesis - 0x27E7, // mathematical right white square bracket - 0x27E9, // mathematical right angle bracket - 0x27EB, // mathematical right double angle bracket - 0x2984, // right white curly bracket - 0x2986, // right white parenthesis - 0x2988, // Z notation right image bracket - 0x298A, // Z notation right binding bracket - 0x298C, // right square bracket with underbar - 0x298E, // right square bracket with tick in top corner - 0x2990, // right square bracket with tick in bottom corner - 0x2992, // right angle bracket with dot - 0x2994, // right arc greater-than bracket - 0x2996, // double right arc less-than bracket - 0x2998, // right black tortoise shell bracket - 0x29D9, // right wiggly fence - 0x29DB, // right double wiggly fence - 0x29FD, // right-pointing curved angle bracket - 0x3009, // CJK right angle bracket - 0x300B, // CJK right double angle bracket - 0x3011, // CJK right black lenticular bracket - 0x3015, // CJK right tortoise shell bracket - 0x3017, // CJK right white lenticular bracket - 0x3019, // CJK right white tortoise shell bracket - 0x301B, // CJK right white square bracket - 0xFD3F, // Ornate right parenthesis - 0xFE5A, // small right parenthesis - 0xFE5C, // small right curly bracket - 0xFF09, // fullwidth right parenthesis - 0xFF3D, // fullwidth right square bracket - 0xFF5D, // fullwidth right curly bracket - '\'', '"', '`', - 0xFF07, // fullwidth apostrophe - 0xFF02, // fullwidth quotation mark - 0x2019, // right single quotation mark (English, others) - 0x201D, // right double quotation mark (English, others) - 0x2018, // left single quotation mark (Czech, German, Slovak) - 0x201C, // left double quotation mark (Czech, German, Slovak) - 0x203A, // single right-pointing angle quotation mark (French, others) - 0x00BB, // right-pointing double angle quotation mark (French, others) - 0x2039, // single left-pointing angle quotation mark 
(Slovenian, others) - 0x00AB, // left-pointing double angle quotation mark (Slovenian, others) - 0x300D, // right corner bracket (East Asian languages) - 0xfe42, // presentation form for vertical right corner bracket - 0xFF63, // halfwidth right corner bracket (East Asian languages) - 0x300F, // right white corner bracket (East Asian languages) - 0xfe44, // presentation form for vertical right white corner bracket - 0x301F, // low double prime quotation mark (East Asian languages) - 0x301E, // close double prime (East Asian languages written horizontally) - 0x00A1, // Spanish inverted exclamation mark - 0x00BF, // Spanish inverted question mark - '.', '!', '?', - 0x055C, // Armenian exclamation mark - 0x055E, // Armenian question mark - 0x0589, // Armenian full stop - 0x061F, // Arabic question mark - 0x06D4, // Arabic full stop - 0x0700, // Syriac end of paragraph - 0x0701, // Syriac supralinear full stop - 0x0702, // Syriac sublinear full stop - 0x0964, // Devanagari danda..Devanagari double danda - 0x0965, - 0x1362, // Ethiopic full stop - 0x1367, // Ethiopic question mark - 0x1368, // Ethiopic paragraph separator - 0x104A, // Myanmar sign little section - 0x104B, // Myanmar sign section - 0x166E, // Canadian syllabics full stop - 0x17d4, // Khmer sign khan - 0x1803, // Mongolian full stop - 0x1809, // Mongolian Manchu full stop - 0x1944, // Limbu exclamation mark - 0x1945, // Limbu question mark - 0x203C, // double exclamation mark - 0x203D, // interrobang - 0x2047, // double question mark - 0x2048, // question exclamation mark - 0x2049, // exclamation question mark - 0x3002, // ideographic full stop - 0x037E, // Greek question mark - 0xFE52, // small full stop - 0xFE56, // small question mark - 0xFE57, // small exclamation mark - 0xFF01, // fullwidth exclamation mark - 0xFF0E, // fullwidth full stop - 0xFF1F, // fullwidth question mark - 0xFF61, // halfwidth ideographic full stop - 0x2026, // ellipsis - 0x30fb, // Katakana middle dot - 0xff65, // halfwidth 
Katakana middle dot - 0x2040, // character tie - '-', '~', - 0x058a, // Armenian hyphen - 0x1806, // Mongolian todo soft hyphen - 0x2010, // hyphen..horizontal bar - 0x2011, 0x2012, 0x2013, 0x2014, 0x2015, - 0x2053, // swung dash -- from Table 6-3 of Unicode book - 0x207b, // superscript minus - 0x208b, // subscript minus - 0x2212, // minus sign - 0x301c, // wave dash - 0x3030, // wavy dash - 0xfe31, // presentation form for vertical em dash..en dash - 0xfe32, - 0xfe58, // small em dash - 0xfe63, // small hyphen-minus - 0xff0d, // fullwidth hyphen-minus - ',', ':', ';', - 0x00b7, // middle dot - 0x0387, // Greek ano teleia - 0x05c3, // Hebrew punctuation sof pasuq - 0x060c, // Arabic comma - 0x061b, // Arabic semicolon - 0x066b, // Arabic decimal separator - 0x066c, // Arabic thousands separator - 0x0703, // Syriac contraction and others - 0x0704, 0x0705, 0x0706, 0x0707, 0x0708, 0x0709, 0x70a, - 0x070c, // Syric harklean metobelus - 0x0e5a, // Thai character angkhankhu - 0x0e5b, // Thai character khomut - 0x0f08, // Tibetan mark sbrul shad - 0x0f0d, // Tibetan mark shad..Tibetan mark rgya gram shad - 0x0f0e, 0x0f0f, 0x0f10, 0x0f11, 0x0f12, - 0x1361, // Ethiopic wordspace - 0x1363, // other Ethiopic chars - 0x1364, 0x1365, 0x1366, - 0x166d, // Canadian syllabics chi sign - 0x16eb, // Runic single punctuation..Runic cross punctuation - 0x16ed, - 0x17d5, // Khmer sign camnuc pii huuh and other - 0x17d6, - 0x17da, // Khmer sign koomut - 0x1802, // Mongolian comma - 0x1804, // Mongolian four dots and other - 0x1805, - 0x1808, // Mongolian manchu comma - 0x3001, // ideographic comma - 0xfe50, // small comma and others - 0xfe51, - 0xfe54, // small semicolon and other - 0xfe55, - 0xff0c, // fullwidth comma - 0xff0e, // fullwidth stop..fullwidth solidus - 0xff0f, - 0xff1a, // fullwidth colon..fullwidth semicolon - 0xff1b, - 0xff64, // halfwidth ideographic comma - 0x2016, // double vertical line - 0x2032, 0x2033, - 0x2034, // prime..triple prime - 0xfe61, // small asterisk 
- 0xfe68, // small reverse solidus - 0xff3c, // fullwidth reverse solidus -}; - -TEST_P(IsPunctuationWordParamTest, IsPunctuation) { - std::string test_string = StringFromUnicodeChar(GetParam()); - EXPECT_TRUE(IsPunctuationWord(test_string)); -} - -INSTANTIATE_TEST_SUITE_P(IsPuncWordParamTest, IsPunctuationWordParamTest, - ::testing::ValuesIn(punc_word_test_cases)); - -class IsEllipsisTest : public ::testing::Test {}; - -TEST_F(IsEllipsisTest, IsEllipsis) { - int offset; - EXPECT_TRUE(IsEllipsis("...", &offset)); - EXPECT_EQ(offset, 3); - EXPECT_TRUE(IsEllipsis("…", &offset)); - EXPECT_EQ(offset, 3); - EXPECT_FALSE(IsEllipsis("@", &offset)); - EXPECT_EQ(offset, 1); -} - -class IsWhiteSpaceTest : public ::testing::Test {}; - -TEST_F(IsWhiteSpaceTest, IsWhiteSpace) { - EXPECT_TRUE(IsWhiteSpace(" ")); - - EXPECT_TRUE(IsWhiteSpace("\n")); - - EXPECT_TRUE(IsWhiteSpace(" ")); - - EXPECT_FALSE(IsWhiteSpace("@")); - - EXPECT_FALSE(IsWhiteSpace("w")); -} - -class IsAcronymTest : public ::testing::Test {}; - -TEST_F(IsAcronymTest, IsAcronym) { - int offset = 0; - EXPECT_TRUE(IsPeriodSeparatedAcronym("U.S.", &offset)); - EXPECT_EQ(offset, 4); - - offset = 0; - EXPECT_TRUE(IsPeriodSeparatedAcronym("E.A.T.", &offset)); - EXPECT_EQ(offset, 6); - - offset = 0; - EXPECT_TRUE(IsPeriodSeparatedAcronym("A.B.C.D.E.F.", &offset)); - EXPECT_EQ(offset, 12); - - offset = 0; - EXPECT_FALSE(IsPeriodSeparatedAcronym("X.", &offset)); - - EXPECT_FALSE(IsPeriodSeparatedAcronym("US", &offset)); - - EXPECT_FALSE(IsPeriodSeparatedAcronym("U-S", &offset)); -} - -class EmoticonParamTest : public SentenceBreakingUtilsStringParamTest {}; - -static const char* const emoticon_test_cases[] = {":(:)", - ":)", - ":(", - ":o)", - ":]", - ":3", - ":>", - "=]", - "=)", - ":}", - ":^)", - ":-D", - ":-)))))", - ":-))))", - ":-)))", - ":-))", - ":-)", - ">:[", - ":-(", - ":(", - ":-c", - ":c", - ":-<", - ":<", - ":-[", - ":[", - ":{", - ";(", - ":-||", - ":@", - ">:(", - ":'-(", - ":'(", - ":'-)", - ":')", - 
"D:<", - ">:O", - ":-O", - ":-o", - ":*", - ":-*", - ":^*", - ";-)", - ";)", - "*-)", - "*)", - ";-]", - ";]", - ";^)", - ":-,", - ">:P", - ":-P", - ":p", - "=p", - ":-p", - "=p", - ":P", - "=P", - ";p", - ";-p", - ";P", - ";-P", - ">:\\", - ">:/", - ":-/", - ":-.", - ":/", - ":\\", - "=/", - "=\\", - ":|", - ":-|", - ":$", - ":-#", - ":#", - "O:-)", - "0:-)", - "0:)", - "0;^)", - ">:)", - ">;)", - ">:-)", - "}:-)", - "}:)", - "3:-)", - ">_>^", - "^<_<", - "|;-)", - "|-O", - ":-J", - ":-&", - ":&", - "#-)", - "<3", - "8-)", - "^_^", - ":D", - ":-D", - "=D", - "^_^;;", - "O=)", - "}=)", - "B)", - "B-)", - "=|", - "-_-", - "o_o;", - "u_u", - ":-\\", - ":s", - ":S", - ":-s", - ":-S", - ";*", - ";-*" - "=(", - ">.<", - ">:-(", - ">:(", - ">=(", - ";_;", - "T_T", - "='(", - ">_<", - "D:", - ":o", - ":-o", - "=o", - "o.o", - ":O", - ":-O", - "=O", - "O.O", - "x_x", - "X-(", - "X(", - "X-o", - "X-O", - ":X)", - "(=^.^=)", - "(=^..^=)", - "=^_^=", - "-<@%", - ":(|)", - "(]:{", - "<\\3", - "~@~", - "8'(", - "XD", - "DX"}; - -TEST_P(EmoticonParamTest, IsEmoticon) { - int offset = 0; - EXPECT_TRUE(IsEmoticon(GetParam(), &offset)); -} - -INSTANTIATE_TEST_SUITE_P(IsEmoticonParamTest, EmoticonParamTest, - ::testing::ValuesIn(emoticon_test_cases)); - -class IsEmoticonTest : public ::testing::Test {}; - -TEST_F(IsEmoticonTest, IsEmoticon) { - int offset = 0; - - EXPECT_TRUE(IsEmoticon(">:-(", &offset)); - - EXPECT_FALSE(IsEmoticon("w", &offset)); - - EXPECT_FALSE(IsEmoticon(":", &offset)); -} - -TEST(SentenceFragmenterTest, Basic) { - // 1 - // 012345678901234 - string test_input = "Hello. 
Foo bar!"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 6); - EXPECT_EQ(fragments[1].start, 7); - EXPECT_EQ(fragments[1].limit, 15); -} - -TEST(SentenceFragmenterTest, BasicEllipsis) { - // 1 - // 012345678901234 - string test_input = "Hello...foo bar"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 8); - EXPECT_EQ(fragments[1].start, 8); - EXPECT_EQ(fragments[1].limit, 15); -} - -TEST(SentenceFragmenterTest, Parentheses) { - // 1 2 - // 012345678901234567890123456789 - string test_input = "Hello (who are you...) foo bar"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 22); - EXPECT_EQ(fragments[1].start, 23); - EXPECT_EQ(fragments[1].limit, 30); -} - -TEST(SentenceFragmenterTest, MidFragmentParentheses) { - // 1 2 - // 012345678901234567890123456789 - string test_input = "Hello (who are you) world? Foo bar"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 26); - EXPECT_EQ(fragments[1].start, 27); - EXPECT_EQ(fragments[1].limit, 34); -} - -TEST(SentenceFragmenterTest, PunctuationAfterParentheses) { - // 1 2 - // 01234567890123456789012345678 - string test_input = "Hello (who are you)? 
Foo bar!"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 20); - EXPECT_EQ(fragments[1].start, 21); - EXPECT_EQ(fragments[1].limit, 29); -} - -TEST(SentenceFragmenterTest, ManyFinalPunctuations) { - // 1 2 - // 0123456789012345678901234 - string test_input = "Hello!!!!! Who are you??"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 10); - EXPECT_EQ(fragments[1].start, 11); - EXPECT_EQ(fragments[1].limit, 24); -} - -TEST(SentenceFragmenterTest, NewLine) { - // 1 2 3 - // 012345678901234567890 1 23456 7 89012 3 45678 - string test_input = "Who let the dogs out?\r\nWho?\r\nWho?\r\nWho?"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 21); - EXPECT_EQ(fragments[1].start, 23); - EXPECT_EQ(fragments[1].limit, 27); - EXPECT_EQ(fragments[2].start, 29); - EXPECT_EQ(fragments[2].limit, 33); - EXPECT_EQ(fragments[3].start, 35); - EXPECT_EQ(fragments[3].limit, 39); -} - -TEST(SentenceFragmenterTest, WhiteSpaceInPunctuation) { - // 1 2 - // 0123456789012345678901234 - string test_input = "Hello?? !!! 
Who are you??"; - SentenceFragmenterV2 fragmenter(test_input); - std::vector<SentenceFragment> fragments; - EXPECT_TRUE(fragmenter.FindFragments(&fragments).ok()); - EXPECT_EQ(fragments[0].start, 0); - EXPECT_EQ(fragments[0].limit, 7); - EXPECT_EQ(fragments[1].start, 8); - EXPECT_EQ(fragments[1].limit, 11); - EXPECT_EQ(fragments[2].start, 12); - EXPECT_EQ(fragments[2].limit, 25); -} - -} // namespace - -TEST(FragmentBoundaryMatchTest, NoStateChange) { - FragmentBoundaryMatch f; - // || - // 012345678901234 - string test_input = "Hello...foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_FALSE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), -1); - EXPECT_EQ(f.first_close_punc_index(), -1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::INITIAL_STATE); -} - -TEST(FragmentBoundaryMatchTest, BasicEllipsis) { - FragmentBoundaryMatch f; - // | | - // 0123456789 - string test_input = "...foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, BasicPeriod) { - FragmentBoundaryMatch f; - // || - // 0123456789 - string test_input = ". Foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, BasicAcronym) { - FragmentBoundaryMatch f; - // | | - // 0123456789 - string test_input = "A.B. 
xyz"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 4); - EXPECT_EQ(f.limit_index(), 4); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, LongerAcronym) { - FragmentBoundaryMatch f; - // | | - // 0123456789 - string test_input = "I.B.M. yo"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 6); - EXPECT_EQ(f.limit_index(), 6); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, Emoticon) { - FragmentBoundaryMatch f; - // | | - // 0123456789012 - string test_input = ">:-( hello..."; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 4); - EXPECT_EQ(f.limit_index(), 4); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, ParensWithEllipsis) { - FragmentBoundaryMatch f; - // || - // 0123456789012345 - string test_input = ".foo...) foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, ClosingParenWithEllipsis) { - FragmentBoundaryMatch f; - // | | - // 012345678901 - string test_input = "...) 
foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, BeginAndEndParenWithEllipsis) { - FragmentBoundaryMatch f; - // || - // 0123456789012 - string test_input = "(...) foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_FALSE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), -1); - EXPECT_EQ(f.first_close_punc_index(), -1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::INITIAL_STATE); - - // | | - // 0123456789012 - test_input = "...) foo bar"; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, AcronymInSentence) { - FragmentBoundaryMatch f; - // | | - // 0123456789012 - string test_input = "U.S. 
don't be surprised."; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 4); - EXPECT_EQ(f.limit_index(), 4); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, HelloWithEllipsis) { - FragmentBoundaryMatch f; - // || - // 01234567890 - string test_input = "o...foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_FALSE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), -1); - EXPECT_EQ(f.first_close_punc_index(), -1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::INITIAL_STATE); - - // | | - // 0123456789 - test_input = "...foo bar"; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -TEST(FragmentBoundaryMatchTest, ThreeStatesWithClosigParen) { - FragmentBoundaryMatch f; - // || - // 0123456789012 - string test_input = "w...) foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_FALSE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), -1); - EXPECT_EQ(f.first_close_punc_index(), -1); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::INITIAL_STATE); - - // | | - // 0123456789012 - test_input = "...) 
foo bar"; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); - - // || - // 0123456789012 - test_input = ") foo bar"; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 0); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_CLOSE_PUNC); - - // || - // 0123456789012 - test_input = " foo bar"; - EXPECT_FALSE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 0); - EXPECT_EQ(f.limit_index(), 1); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_CLOSE_PUNC); -} - -TEST(FragmentBoundaryMatchTest, NoTransition) { - FragmentBoundaryMatch f; - // | | - // 0123456789012 - string test_input = "...foo bar"; - int index = 0; - EXPECT_TRUE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); - - // || - // 0123456789012 - test_input = "foo bar"; - EXPECT_FALSE(f.Advance(index, test_input)); - EXPECT_TRUE(f.GotTerminalPunc()); - EXPECT_EQ(f.first_terminal_punc_index(), 0); - EXPECT_EQ(f.first_close_punc_index(), 3); - EXPECT_EQ(f.limit_index(), 3); - EXPECT_EQ(f.state(), FragmentBoundaryMatch::COLLECTING_TERMINAL_PUNC); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.cc b/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.cc deleted file mode 100644 index 47cb94d1f..000000000 --- 
a/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h" - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddSentenceFragmenterV2(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel< - tensorflow::text::SentenceFragmenterV2Op>::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h b/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h index 7f0694eb2..091283d9d 100644 --- a/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h +++ b/tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h @@ -15,19 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_TFLITE_H_ -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void 
AddSentenceFragmenterV2(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/sentence_fragmenter_v2_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SENTENCE_FRAGMENTER_V2_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/BUILD b/tensorflow_text/core/kernels/sentencepiece/BUILD index 78e1eb8fd..32364e31b 100644 --- a/tensorflow_text/core/kernels/sentencepiece/BUILD +++ b/tensorflow_text/core/kernels/sentencepiece/BUILD @@ -1,366 +1,181 @@ -# Memorymappable, WASM compilable, implementation of the encoder. -# +"""Sentencepiece kernels for tf.text ops. +All implementation files moved to //third_party/tensorflow/core/kernels/text/sentencepiece. +""" -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") -load("@rules_cc//cc:cc_test.bzl", "cc_test") -load("//tensorflow_text:tftext.bzl", "tf_cc_library", "tflite_cc_library") licenses(["notice"]) -# Visibility rules -package(default_visibility = ["//visibility:public"]) +package( + default_applicable_licenses = ["//tensorflow_text:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//visibility:public"], +) + +# Aliases to relocated targets + +ALIAS_NAMES = [ + "testdata", + "config_fbs", + "sp_headers", + "config", + "encoder_config", + "decoder_config", + "double_array_trie_test", + "sentencepiece_tokenizer_kernel", + "sentencepiece_detokenizer_kernel", + "sentencepiece_tokenizer_tflite", + "sentencepiece_detokenizer_tflite", + "optimized_encoder_test", + "optimized_decoder_test", + "macos", + "apple", +] -filegroup( +alias( name = "testdata", - srcs = [ - "//tensorflow_text:python/ops/test_data/fast_sentencepiece.model", - ], + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:testdata", ) -filegroup( +alias( name = "config_fbs", - srcs = ["config.fbs"], + 
actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:config_fbs", ) -filegroup( +alias( name = "sp_headers", - srcs = [ - "py_tflite_registerer.h", - ], - visibility = ["//visibility:public"], + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sp_headers", ) -flatbuffer_cc_library( +alias( name = "config", - srcs = [ - "config.fbs", - ], + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:config", ) -flatbuffer_cc_library( +alias( name = "encoder_config", - srcs = [ - "encoder_config.fbs", - ], - includes = [":config_fbs"], + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:encoder_config", ) -flatbuffer_cc_library( +alias( name = "decoder_config", - srcs = [ - "decoder_config.fbs", - ], - includes = [":config_fbs"], + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:decoder_config", +) + +alias( + name = "double_array_trie_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:double_array_trie_test", +) + +alias( + name = "sentencepiece_tokenizer_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_tokenizer_kernel", +) + +alias( + name = "sentencepiece_detokenizer_kernel", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_detokenizer_kernel", +) + +alias( + name = "sentencepiece_tokenizer_tflite", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_tokenizer_tflite", +) + +alias( + name = "sentencepiece_detokenizer_tflite", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_detokenizer_tflite", ) +alias( + name = "optimized_encoder_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:optimized_encoder_test", +) + +alias( + name = "optimized_decoder_test", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:optimized_decoder_test", +) + +alias( + name 
= "macos", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:macos", +) + +alias( + name = "apple", + actual = "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:apple", +) + +LIBRARY_HEADERS = { + "utils": "utils.h", + "double_array_trie": "double_array_trie.h", + "double_array_trie_builder": "double_array_trie_builder.h", + "sentencepiece_constants": "sentencepiece_constants.h", + "model_converter": "model_converter.h", + "optimized_encoder": "optimized_encoder.h", + "optimized_decoder": "optimized_decoder.h", + "sentencepiece_tokenizer_h": "sentencepiece_tokenizer.h", + "sentencepiece_detokenizer_h": "sentencepiece_detokenizer.h", + "py_tflite_registerer": "py_tflite_registerer.h", +} + cc_library( name = "utils", - srcs = [ - ], - hdrs = [ - "utils.h", - ], + hdrs = ["utils.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:utils"], ) cc_library( name = "double_array_trie", - srcs = [ - ], - hdrs = [ - "double_array_trie.h", - ], - deps = [ - ":config", - ":utils", - ], + hdrs = ["double_array_trie.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:double_array_trie"], ) cc_library( name = "double_array_trie_builder", - srcs = [ - "double_array_trie_builder.cc", - ], - hdrs = [ - "double_array_trie_builder.h", - ], - deps = [ - ":config", - ":utils", - "@darts_clone", - ], -) - -cc_test( - name = "double_array_trie_test", - srcs = [ - "double_array_trie_test.cc", - ], - deps = [ - ":double_array_trie", - ":double_array_trie_builder", - ":encoder_config", - "@com_google_googletest//:gtest_main", - ], + hdrs = ["double_array_trie_builder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:double_array_trie_builder"], ) cc_library( name = "sentencepiece_constants", - srcs = [], hdrs = ["sentencepiece_constants.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_constants"], ) cc_library( name = "model_converter", - srcs = 
[ - "model_converter.cc", - ], - hdrs = [ - "model_converter.h", - ], - deps = [ - ":config", - ":decoder_config", - ":double_array_trie_builder", - ":encoder_config", - ":sentencepiece_constants", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_sentencepiece//:sentencepiece_model_cc_proto", - ], + hdrs = ["model_converter.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:model_converter"], ) cc_library( name = "optimized_encoder", - srcs = [ - "optimized_encoder.cc", - ], - hdrs = [ - "optimized_encoder.h", - ], - deps = [ - ":config", - ":double_array_trie", - ":encoder_config", - ], + hdrs = ["optimized_encoder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:optimized_encoder"], ) cc_library( name = "optimized_decoder", - srcs = [ - "optimized_decoder.cc", - ], - hdrs = [ - "optimized_decoder.h", - ], - deps = [ - "config", - ":decoder_config", - ":double_array_trie", - ], + hdrs = ["optimized_decoder.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:optimized_decoder"], ) cc_library( name = "sentencepiece_tokenizer_h", - hdrs = [ - "sentencepiece_tokenizer.h", - ], + hdrs = ["sentencepiece_tokenizer.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_tokenizer_h"], ) cc_library( name = "sentencepiece_detokenizer_h", - hdrs = [ - "sentencepiece_detokenizer.h", - ], -) - -tf_cc_library( - name = "sentencepiece_tokenizer_kernel", - srcs = ["sentencepiece_tokenizer_kernel.cc"], - tf_deps = [ - # tf:lib tensorflow dep, - # tf:framework tensorflow dep, - ], - deps = [ - ":optimized_encoder", - ":sentencepiece_tokenizer_h", - ], -) - -tf_cc_library( - name = "sentencepiece_detokenizer_kernel", - srcs = ["sentencepiece_detokenizer_kernel.cc"], - tf_deps = [ - # tf:lib tensorflow dep, - # tf:framework tensorflow dep, - ], - deps = [ - ":optimized_decoder", - 
":sentencepiece_detokenizer_h", - # tf/protobuf:error_codes_proto_impl_cc tensorflow dep, - ], -) - -tflite_cc_library( - name = "sentencepiece_tokenizer_tflite", - srcs = ["sentencepiece_tokenizer_tflite.cc"], - deps = - [ - ":optimized_encoder", - ":sentencepiece_tokenizer_h", - "@flatbuffers", - # lite:framework tensorflow dep, - # lite:string_util tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels:kernel_util tensorflow dep, - # lite/kernels/internal:tensor tensorflow dep, - ], + hdrs = ["sentencepiece_detokenizer.h"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:sentencepiece_detokenizer_h"], ) -tflite_cc_library( - name = "sentencepiece_detokenizer_tflite", - srcs = ["sentencepiece_detokenizer_tflite.cc"], - deps = - [ - ":optimized_decoder", - ":sentencepiece_detokenizer_h", - "@flatbuffers", - # lite:framework tensorflow dep, - # lite:string_util tensorflow dep, - # lite/c:common tensorflow dep, - # lite/kernels:kernel_util tensorflow dep, - # lite/kernels/internal:tensor tensorflow dep, - ], -) - -cc_test( - name = "optimized_encoder_test", - srcs = [ - "optimized_encoder_test.cc", - ], - data = [ - ":testdata", - ], - deps = [ - ":double_array_trie_builder", - ":encoder_config", - ":model_converter", - ":optimized_encoder", - "//file/base:path", - "//file/localfile", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_sentencepiece//:sentencepiece_cc_proto", - "@com_google_sentencepiece//:sentencepiece_processor", - # tf:lib tensorflow dep, - # lite/kernels:test_util tensorflow dep, - ], -) - -cc_test( - name = "optimized_decoder_test", - srcs = [ - "optimized_decoder_test.cc", - ], - data = [ - ":testdata", - ], - deps = [ - ":model_converter", - ":optimized_decoder", - "//file/base:path", - "//file/localfile", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings:str_format", - 
"@com_google_sentencepiece//:sentencepiece_cc_proto", - "@com_google_sentencepiece//:sentencepiece_processor", - # tf:lib tensorflow dep, - # lite/kernels:test_util tensorflow dep, - ], -) - -tflite_cc_library( +cc_library( name = "py_tflite_registerer", - srcs = ["py_tflite_registerer.cc"], hdrs = ["py_tflite_registerer.h"], - deps = [ - ":sentencepiece_detokenizer_tflite", - ":sentencepiece_tokenizer_tflite", - # lite:framework tensorflow dep, - # lite/kernels:builtin_ops tensorflow dep, - ], - alwayslink = 1, -) - -config_setting( - name = "armeabi_v7a_and_fastbuild", - constraint_values = ["//third_party/bazel_platforms/cpu:armv7"], - values = { - "compilation_mode": "fastbuild", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi_v7a_and_dbg", - constraint_values = ["//third_party/bazel_platforms/cpu:armv7"], - values = { - "compilation_mode": "dbg", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_i386", - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_x86_64", - values = { - "apple_platform_type": "macos", - "cpu": "darwin_x86_64", - }, - visibility = ["//visibility:public"], -) - -alias( - name = "macos", - actual = select({ - ":macos_i386": ":macos_i386", - ":macos_x86_64": ":macos_x86_64", - "//conditions:default": ":macos_i386", # Arbitrarily chosen from above. - }), - visibility = ["//visibility:public"], -) - -config_setting( - name = "ios", - values = { - "crosstool_top": "@bazel_tools//tools/cpp:toolchain", - "apple_platform_type": "ios", - }, - visibility = ["//visibility:public"], -) - -alias( - name = "apple", - actual = select({ - ":macos": ":macos", - ":ios": ":ios", - "//conditions:default": ":ios", # Arbitrarily chosen from above. 
- }), - visibility = ["//visibility:public"], + deps = ["@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:py_tflite_registerer"], ) diff --git a/tensorflow_text/core/kernels/sentencepiece/config.fbs b/tensorflow_text/core/kernels/sentencepiece/config.fbs deleted file mode 100644 index 4b9cc9a81..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/config.fbs +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -namespace tensorflow.text.sentencepiece; - -table Trie { - nodes: [uint32]; -} - - -enum EncoderVersion: byte { - SENTENCE_PIECE = 0, -} diff --git a/tensorflow_text/core/kernels/sentencepiece/decoder_config.fbs b/tensorflow_text/core/kernels/sentencepiece/decoder_config.fbs deleted file mode 100644 index bcb787dbd..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/decoder_config.fbs +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "config.fbs"; - -namespace tensorflow.text.sentencepiece; - - -table DecoderConfig { - version: EncoderVersion = SENTENCE_PIECE; - - // The offset for encoding, usually used when codes with low codes are reserved - // for some special needs. - encoding_offset: int32; - - // A vector of strings that represent sentencepieces. - decode_pieces: [string]; - - // TODO(mgubin): Currently is not populated, haven't seen any Sentencepiece - // model with a denormalizer. - denormalized_prefixes: Trie; - denormalized_replacements: [byte]; - - // During encoding a dummy prefix (a whitespace) can be added to the input string, - // if this flag is true, this prefix will be removed. - remove_dummy_prefix: bool; - -} - - -root_type DecoderConfig; diff --git a/tensorflow_text/core/kernels/sentencepiece/double_array_trie.h b/tensorflow_text/core/kernels/sentencepiece/double_array_trie.h index 0599cb641..a8ac801c2 100644 --- a/tensorflow_text/core/kernels/sentencepiece/double_array_trie.h +++ b/tensorflow_text/core/kernels/sentencepiece/double_array_trie.h @@ -12,121 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/double_array_trie.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ - -#include <functional> -#include <vector> - -#include "tensorflow_text/core/kernels/sentencepiece/config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/utils.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -// A trie node specifies a node in the tree, either an intermediate node or -// a leaf node. -// A leaf node contains the id as an int of the string match. This id is encoded -// in the lower 31 bits, thus the number of distinct ids is 2^31. -// An intermediate node has an associated label and an offset to its children. -// The label is encoded in the least significant byte and must match the input -// character during matching. - -// A memory mappable trie, compatible with Darts::DoubleArray. 
-class DoubleArrayTrie { - public: - struct Match { - Match() {} - Match(int id, int match_length) : id(id), match_length(match_length) {} - int id = -1; - int match_length = -1; - bool empty() const { return match_length == -1; } - bool operator==(const Match& m) const { - return m.id == id && m.match_length == match_length; - } - }; - - // nodes and nodes_length specify the array of the nodes of the trie. - explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes) - : nodes_(nodes) {} - - // Finds matches that are prefixes of a string. - template <typename callback> - void IteratePrefixMatches(const utils::string_view& input, - callback update_fn) const; - - // Finds the longest prefix match of a string. - Match LongestPrefixMatch(const utils::string_view& input) const { - Match match; - IteratePrefixMatches(input, [&match](const Match& m) { match = m; }); - return match; - } - - private: - // Returns whether a node as a leaf as a child. - bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; } - - // Returns a value associated with a node. Available when a node is a leaf. - int value(uint32_t i) const { - return static_cast<int>(((*nodes_)[i]) & 0x7fffffff); - } - - // Returns a label associated with a node. - // A leaf node will have the MSB set and thus return an invalid label. - int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; } - - // Returns offset to children. 
- int32_t offset(uint32_t i) const { - const uint32_t node = (*nodes_)[i]; - return (node >> 10) << ((node & 0x200) >> 6); - } - - const flatbuffers::Vector<uint32_t>* nodes_; -}; - -template <typename callback> -void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input, - callback update_fn) const { - if (nodes_->size() == 0) { - return; - } - uint32_t pos = offset(0); - for (int i = 0; i < input.length(); ++i) { - pos ^= static_cast<unsigned char>(input.at(i)); - if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) { - // No match, exit. - return; - } - const bool node_has_leaf = has_leaf(pos); - pos ^= offset(pos); - if (pos < 0 || pos >= nodes_->size()) { - // We can get here only if the structure is corrupted. - return; - } - if (node_has_leaf) { - update_fn(Match(value(pos), i + 1)); - } - } -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.cc b/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.cc deleted file mode 100644 index 7e5bdae64..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.cc +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h" - -#include <algorithm> -#include <memory> - -#include "include/darts.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) { - std::vector<int> ids; - ids.reserve(data.size()); - for (int i = 0; i < data.size(); ++i) { - ids.push_back(i); - } - return BuildTrie(data, ids); -} - -std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, - const std::vector<int>& ids) { - // We make strong assumptions about binary structure of trie. - struct OneElement { - OneElement(const std::string* key_, int index_) - : key(key_), index(index_) {} - const std::string* key; - int index; - bool operator<(const OneElement& el) const { return *key < *el.key; } - }; - std::vector<OneElement> elements; - elements.reserve(data.size()); - auto data_iterator = std::begin(data); - auto ids_iterator = std::begin(ids); - for (; data_iterator != std::end(data) && ids_iterator != std::end(ids); - ++data_iterator, ++ids_iterator) { - elements.emplace_back(&(*data_iterator), *ids_iterator); - } - // Sort by keys. 
- std::sort(elements.begin(), elements.end()); - - // Create vectors to build the trie. - std::vector<const char*> strings; - std::vector<int32_t> indexes; - strings.reserve(data.size()); - indexes.reserve(data.size()); - for (const auto& el : elements) { - strings.push_back(el.key->c_str()); - indexes.push_back(el.index); - } - auto trie = std::make_unique<Darts::DoubleArray>(); - trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr, - &indexes[0]); - // We make strong assumptions about internal Darts trie structure: - // - it is a vector of 32 bit signed integers - // - the "array" is the only one structure that contains all information about - // the trie. - const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array()); - return std::vector<uint32_t>(trie_data, trie_data + trie->size()); -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h b/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h index 1e585f99a..a3c444398 100644 --- a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h +++ b/tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h @@ -12,42 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/double_array_trie_builder.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ - -#include <string> -#include <vector> - -#include "tensorflow_text/core/kernels/sentencepiece/config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/utils.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, - const std::vector<int>& ids); - -// A variant where ids are indexes in data. -std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data); - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_test.cc b/tensorflow_text/core/kernels/sentencepiece/double_array_trie_test.cc deleted file mode 100644 index 118a0573a..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/double_array_trie_test.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h" -#include "tensorflow_text/core/kernels/sentencepiece/encoder_config_generated.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -TEST(DoubleArrayTrieTest, Match) { - flatbuffers::FlatBufferBuilder builder(1024); - const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"}; - const auto trie_vector = builder.CreateVector(BuildTrie(test_strings)); - TrieBuilder trie_builder(builder); - trie_builder.add_nodes(trie_vector); - const auto pieces = trie_builder.Finish(); - EncoderConfigBuilder ecb(builder); - ecb.add_pieces(pieces); - FinishEncoderConfigBuffer(builder, ecb.Finish()); - const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); - DoubleArrayTrie dat(config->pieces()->nodes()); - EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")), - DoubleArrayTrie::Match(2, 2)); - - std::vector<DoubleArrayTrie::Match> matches; - dat.IteratePrefixMatches( - utils::string_view("AAXL"), - [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); - EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1), - DoubleArrayTrie::Match(2, 2), - DoubleArrayTrie::Match(1, 3))); -} - -TEST(DoubleArrayTrieTest, ComplexMatch) { - flatbuffers::FlatBufferBuilder builder(1024); - const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s", - "\xe2\x96\x81Hello"}; - const std::vector<int> test_ids = {0, 5, 10, 15}; - const auto trie_vector = - builder.CreateVector(BuildTrie(test_strings, test_ids)); - TrieBuilder trie_builder(builder); - trie_builder.add_nodes(trie_vector); - const auto pieces = trie_builder.Finish(); - EncoderConfigBuilder ecb(builder); - ecb.add_pieces(pieces); - 
FinishEncoderConfigBuffer(builder, ecb.Finish()); - const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); - DoubleArrayTrie dat(config->pieces()->nodes()); - - std::vector<DoubleArrayTrie::Match> matches; - dat.IteratePrefixMatches( - utils::string_view("\xe2\x96\x81Hello"), - [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); - EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8))); -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/encoder_config.fbs b/tensorflow_text/core/kernels/sentencepiece/encoder_config.fbs deleted file mode 100644 index 1c98ddde1..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/encoder_config.fbs +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -include "config.fbs"; - -namespace tensorflow.text.sentencepiece; - -table EncoderConfig { - // Version of the encoder. - version: EncoderVersion = SENTENCE_PIECE; - start_code: int32 = 0; - end_code: int32 = 0; - - unknown_code: int32 = -1; - // Weight of "unknown code" when encoding. "Penalty" because it usually has a - // big negative weight,less than any other sentencepiece. - unknown_penalty: float = 0; - - // The offset for encoding, usually used when codes with low codes are reserved - // for some special needs. 
- encoding_offset: int32; - - // String pieces for encoding. - pieces: Trie; - pieces_scores: [float]; - - // Normalization related parameters. - remove_extra_whitespaces: bool; - - // Add a whitespace prefix before encoding. - add_dummy_prefix: bool; - - // Escape whitespaces during encoding so the decoder can restore them exactly as - // in the input. - escape_whitespaces: bool; - - // Normalization parameters. - normalized_prefixes: Trie; - normalized_replacements: [byte]; -} - -root_type EncoderConfig; diff --git a/tensorflow_text/core/kernels/sentencepiece/model_converter.cc b/tensorflow_text/core/kernels/sentencepiece/model_converter.cc deleted file mode 100644 index bdaaff375..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/model_converter.cc +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/model_converter.h" -#include <tuple> - -#include "absl/status/status.h" -#include "absl/strings/str_replace.h" -#include "src/sentencepiece_model.pb.h" -#include "tensorflow_text/core/kernels/sentencepiece/decoder_config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h" -#include "tensorflow_text/core/kernels/sentencepiece/encoder_config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_constants.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -std::tuple<std::vector<uint32_t>, std::vector<int8_t>> -DecodePrecompiledCharsmap( - const ::sentencepiece::NormalizerSpec& normalizer_spec) { - // This function "undoes" encoding done by - // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap. 
- if (normalizer_spec.precompiled_charsmap().empty()) { - return std::make_tuple(std::vector<uint32_t>(), std::vector<int8_t>()); - } - const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); - const uint32_t trie_size = - *reinterpret_cast<const uint32_t*>(precompiled_map); - const uint32_t* trie_ptr = - reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t)); - const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>( - precompiled_map + sizeof(uint32_t) + trie_size); - const int normalized_size = normalizer_spec.precompiled_charsmap().length() - - sizeof(uint32_t) - trie_size; - return std::make_tuple( - std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), - std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size)); -} - -absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( - const std::string& model_config_str, int encoding_offset) { - ::sentencepiece::ModelProto model_config; - if (!model_config.ParseFromString(model_config_str)) { - return absl::InvalidArgumentError( - "Invalid configuration, can't parse SentencePiece model config " + - model_config.InitializationErrorString()); - } - // Convert sentencepieces. 
- std::vector<std::string> pieces; - pieces.reserve(model_config.pieces_size()); - std::vector<float> scores; - scores.reserve(model_config.pieces_size()); - std::vector<int> ids; - ids.reserve(model_config.pieces_size()); - float min_score = 0.0; - int index = 0; - for (const auto& piece : model_config.pieces()) { - switch (piece.type()) { - case ::sentencepiece::ModelProto::SentencePiece::NORMAL: - case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: - pieces.push_back(piece.piece()); - ids.push_back(index); - if (piece.score() < min_score) { - min_score = piece.score(); - } - break; - case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: - case ::sentencepiece::ModelProto::SentencePiece::CONTROL: - case ::sentencepiece::ModelProto::SentencePiece::BYTE: - // Ignore unknown and control codes. - break; - default: - return absl::InvalidArgumentError("Invalid SentencePiece piece type " + - piece.piece()); - } - scores.push_back(piece.score()); - ++index; - } - flatbuffers::FlatBufferBuilder builder(1024); - const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); - const auto pieces_score_vector = builder.CreateVector(scores); - TrieBuilder pieces_trie_builder(builder); - pieces_trie_builder.add_nodes(pieces_trie_vector); - const auto pieces_trie_fbs = pieces_trie_builder.Finish(); - - // Converting normalization. 
- const auto normalization = - DecodePrecompiledCharsmap(model_config.normalizer_spec()); - const auto normalization_trie = std::get<0>(normalization); - const auto normalization_strings = std::get<1>(normalization); - const auto normalization_trie_vector = - builder.CreateVector(normalization_trie); - TrieBuilder normalization_trie_builder(builder); - normalization_trie_builder.add_nodes(normalization_trie_vector); - const auto normalization_trie_fbs = normalization_trie_builder.Finish(); - const auto normalization_strings_fbs = - builder.CreateVector(normalization_strings); - - EncoderConfigBuilder ecb(builder); - ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); - ecb.add_start_code(model_config.trainer_spec().bos_id()); - ecb.add_end_code(model_config.trainer_spec().eos_id()); - ecb.add_unknown_code(model_config.trainer_spec().unk_id()); - ecb.add_unknown_penalty(min_score - kUnkPenalty); - ecb.add_encoding_offset(encoding_offset); - ecb.add_pieces(pieces_trie_fbs); - ecb.add_pieces_scores(pieces_score_vector); - ecb.add_remove_extra_whitespaces( - model_config.normalizer_spec().remove_extra_whitespaces()); - ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); - ecb.add_escape_whitespaces( - model_config.normalizer_spec().escape_whitespaces()); - ecb.add_normalized_prefixes(normalization_trie_fbs); - ecb.add_normalized_replacements(normalization_strings_fbs); - FinishEncoderConfigBuffer(builder, ecb.Finish()); - return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize()); -} - -absl::StatusOr<std::string> -ConvertSentencepieceModelToFlatBufferForDecoder( - const std::string& model_config_str, int encoding_offset) { - ::sentencepiece::ModelProto model_config; - if (!model_config.ParseFromString(model_config_str)) { - return absl::InvalidArgumentError( - "Invalid configuration, can't parse SentencePiece model config " + - model_config.InitializationErrorString()); - } - 
flatbuffers::FlatBufferBuilder builder(1024); - // Collect sentencepieces. - std::vector<std::string> pieces; - for (const auto& piece : model_config.pieces()) { - // In the original library all pieces processing is done during decoding. - // Because it is independent from context or parameters we can do it in - // advance here. - switch (piece.type()) { - case ::sentencepiece::ModelProto::SentencePiece::NORMAL: - case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: - pieces.push_back( - absl::StrReplaceAll(piece.piece(), {{kSpaceSymbol, " "}})); - break; - case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: - pieces.push_back( - kDefaultUnknownSymbol); // Always decode with the default unknown. - break; - default: - pieces.push_back(""); - } - } - const auto pieces_fbs = builder.CreateVectorOfStrings(pieces); - DecoderConfigBuilder decb(builder); - - decb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); - decb.add_encoding_offset(encoding_offset); - decb.add_decode_pieces(pieces_fbs); - decb.add_remove_dummy_prefix( - model_config.normalizer_spec().add_dummy_prefix()); - - FinishDecoderConfigBuffer(builder, decb.Finish()); - return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), - builder.GetSize()); -} - -int GetVocabularySize(const std::string& model_string) { - const EncoderConfig* config = GetEncoderConfig(model_string.data()); - return config->pieces_scores()->size() + config->encoding_offset(); -} - -std::string ConvertSentencepieceModel(const std::string& model_string) { - const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); - // TODO(mgubin): Propogate error to the Python code and throw correct - // exception. 
- assert(result.status().ok()); - return result.value(); -} - -std::string ConvertSentencepieceModelForDecoder( - const std::string& model_string) { - const auto result = - ConvertSentencepieceModelToFlatBufferForDecoder(model_string); - // TODO(mgubin): Propogate error to the Python code and throw correct - // exception. - assert(result.status().ok()); - return result.value(); -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/model_converter.h b/tensorflow_text/core/kernels/sentencepiece/model_converter.h index 716e989b4..faea9f55d 100644 --- a/tensorflow_text/core/kernels/sentencepiece/model_converter.h +++ b/tensorflow_text/core/kernels/sentencepiece/model_converter.h @@ -12,53 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_MODEL_CONVERTER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_MODEL_CONVERTER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/model_converter.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ -#include <string> - -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -// Converts Sentencepiece configuration to flatbuffer format. -// encoding_offset is used by some encoders that combine different encodings. -absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( - const std::string& model_config_str, int encoding_offset = 0); - -// Converts Sentencepiece configuration to flatbuffer format for encoder. -// encoding_offset is used by some encoders that combine different encodings. -absl::StatusOr<std::string> -ConvertSentencepieceModelToFlatBufferForDecoder( - const std::string& model_config_str, int encoding_offset = 0); - -// The functions that are provided for the Python wrapper. -std::string ConvertSentencepieceModel(const std::string& model_string); -std::string ConvertSentencepieceModelForDecoder( - const std::string& model_string); - -// Returns size of a vocabulary from Sentencepiece configuration in flatbuffer -// format. 
-int GetVocabularySize(const std::string& model_string); - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_MODEL_CONVERTER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/native.bzl b/tensorflow_text/core/kernels/sentencepiece/native.bzl deleted file mode 100644 index 0d0d2184d..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/native.bzl +++ /dev/null @@ -1,89 +0,0 @@ -"""Build definitions supporting platform-independent native build.""" - -load("//third_party/bazel_skylib/lib:selects.bzl", "selects") -load("//third_party/tensorflow:tensorflow.bzl", "tf_copts", "tf_opts_nortti_if_android") - -def micore_if(android, ios = [], default = []): - """Helper to create a select. - - Args: - android: what to return if compiling for Android. - ios: what to return if compiling for iOS. - default: what to return otherwise. - Returns: - the `android` list for Android compilation and the - `default` list otherwise. - """ - return select({ - "//tools/cc_target_os:android": android, - "//tools/cc_target_os:apple": ios, - "//conditions:default": default, - }) - -def micore_tf_copts(): - """C options for Tensorflow builds. - - Returns: - a list of copts which must be used by each cc_library which - refers to Tensorflow. Enables the library to compile both for - Android and for Google3. - """ - return tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-Wno-narrowing", - "-Wno-sign-compare", - "-Wno-overloaded-virtual", - ] + micore_if( - android = [ - # Set a define so Tensorflow's register_types.h - # adopts to support a rich set of types, to be pruned by - # selective registration. 
- "-DSUPPORT_SELECTIVE_REGISTRATION", - # Selective registration uses constexprs with recursive - # string comparisons; that can lead to compiler errors, so - # we increase the constexpr recursion depth. - "-fconstexpr-depth=1024", - ], - ) + selects.with_or({ - # If building for armeabi-v7a, and if compilation_mode is 'fastbuild' - # or 'dbg' then forcefully add -Oz to the list compiler options. - # Without it, some TF dependencies can't build (b/112286436). If - # compilation_mode is 'opt' then rely on the toolchain default. - ( - "//intelligence/micore/tools/build:armeabi_v7a_and_fastbuild", - "//intelligence/micore/tools/build:armeabi_v7a_and_dbg", - ): ["-Oz"], - "//conditions:default": [], - }) - -def micore_tf_deps(): - """Dependencies for Tensorflow builds. - - Returns: - list of dependencies which must be used by each cc_library - which refers to Tensorflow. Enables the library to compile both for - Android and for Google3. Use this macro instead of directly - declaring dependencies on Tensorflow. - """ - return micore_if( - android = [ - # Link to library which does not contain any ops. - # tf:portable_tensorflow_lib_lite tensorflow dep, - "//third_party/gemmlowp:eight_bit_int_gemm", - "//third_party/fft2d", - ], - ios = [ - # tf:portable_tensorflow_lib tensorflow dep, - "//third_party/gemmlowp:eight_bit_int_gemm", - "//third_party/fft2d", - ], - default = [ - # Standard references for Tensorflow when building for non-mobile, plain Google3. We use - # an indirection via the alias targets below, to facilitate whitelisting these deps in - # the mobile license presubmit checks. 
- "//intelligence/micore/tools/build:tensorflow_core_cpu", - "//intelligence/micore/tools/build:tensorflow_core_framework", - "//intelligence/micore/tools/build:tensorflow_core_lib", - "//intelligence/micore/tools/build:tensorflow_core_protos_all_cc", - "//intelligence/micore/tools/build:tensorflow_core_tensorflow", - ], - ) diff --git a/tensorflow_text/core/kernels/sentencepiece/native.bzl.oss b/tensorflow_text/core/kernels/sentencepiece/native.bzl.oss deleted file mode 100644 index c12530abf..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/native.bzl.oss +++ /dev/null @@ -1,87 +0,0 @@ -"""Build definitions supporting platform-independent native build.""" - -load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_copts", "tf_opts_nortti_if_android") -load("@bazel_skylib//lib:selects.bzl", "selects") - -def micore_if(android, ios = [], default = []): - """Helper to create a select. - - Args: - android: what to return if compiling for Android. - ios: what to return if compiling for iOS. - default: what to return otherwise. - Returns: - the `android` list for Android compilation and the - `default` list otherwise. - """ - return select({ - ":android": android, - ":apple": ios, - "//conditions:default": default, - }) - -def micore_tf_copts(): - """C options for Tensorflow builds. - - Returns: - a list of copts which must be used by each cc_library which - refers to Tensorflow. Enables the library to compile both for - Android and for Linux. - """ - return tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-Wno-narrowing", - "-Wno-sign-compare", - "-Wno-overloaded-virtual", - ] + micore_if( - android = [ - # Set a define so Tensorflow's register_types.h - # adopts to support a rich set of types, to be pruned by - # selective registration. 
- "-DSUPPORT_SELECTIVE_REGISTRATION", - # Selective registration uses constexprs with recursive - # string comparisons; that can lead to compiler errors, so - # we increase the constexpr recursion depth. - "-fconstexpr-depth=1024", - ], - ) + selects.with_or({ - # If building for armeabi-v7a, and if compilation_mode is 'fastbuild' - # or 'dbg' then forcefully add -Oz to the list compiler options. - # Without it, some TF dependencies can't build (b/112286436). If - # compilation_mode is 'opt' then rely on the toolchain default. - ( - ":armeabi_v7a_and_fastbuild", - ":armeabi_v7a_and_dbg", - ): ["-Oz"], - "//conditions:default": [], - }) - -def micore_tf_deps(): - """Dependencies for Tensorflow builds. - - Returns: - list of dependencies which must be used by each cc_library - which refers to Tensorflow. Enables the library to compile both for - Android and for Linux. Use this macro instead of directly - declaring dependencies on Tensorflow. - """ - return micore_if( - android = [ - # Link to library which does not contain any ops. - "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", - "@gemmlowp//:eight_bit_int_gemm", - "@fft2d//:fft2d", - ], - ios = [ - "@org_tensorflow//tensorflow/core:portable_tensorflow_lib", - "@gemmlowp//:eight_bit_int_gemm", - "@fft2d//:fft2d", - ], - default = [ - # Standard references for Tensorflow when building for Linux. We use - # an indirection via the alias targets below, to facilitate whitelisting - # these deps in the mobile license presubmit checks. - "@release_or_nightly//:tensorflow_libtensorflow_framework", - "@release_or_nightly//:tensorflow_tf_header_lib", - ], - - ) diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.cc b/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.cc deleted file mode 100644 index 397349c58..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2026 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" - -#include <string> -#include <tuple> - -#include "tensorflow_text/core/kernels/sentencepiece/decoder_config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -DecoderResult DecodeString(const std::vector<int>& encoded, - const void* config_buffer) { - DecoderResult result; - - // Get the config from the buffer. 
- const DecoderConfig* config = GetDecoderConfig(config_buffer); - if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { - result.type = DecoderResultType::WRONG_CONFIG; - return result; - } - bool remove_dummy_prefix = config->remove_dummy_prefix(); - const auto config_pieces = config->decode_pieces(); - for (const auto code : encoded) { - const int real_code = code - config->encoding_offset(); - if (real_code >= config_pieces->size()) { - result.type = DecoderResultType::INVALID_INPUT; - return result; - } - const auto& piece_text = config_pieces->GetAsString(real_code); - const char* piece_str = piece_text->c_str(); - if (remove_dummy_prefix && *piece_str == ' ') { - ++piece_str; - } - result.decoded.append(piece_str); - remove_dummy_prefix = false; - } - // TODO(mgubin): Denormalize the string, haven't seen any Sentencepiece model - // with a denormalizer. - return result; -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h b/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h index 8513d8e06..50d1fa4c5 100644 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h +++ b/tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h @@ -12,51 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_DECODER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_DECODER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/optimized_decoder.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ - -// Sentencepiece decoder optimized with memmapped model. - -#include <string> -#include <vector> - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -enum class DecoderResultType { - SUCCESS = 0, - WRONG_CONFIG = 1, - INVALID_INPUT = 2 -}; - -struct DecoderResult { - DecoderResultType type = DecoderResultType::SUCCESS; - std::string decoded; -}; - -// Decodes one string from a vector of id. Takes the configuration as a -// type-erased buffer. -DecoderResult DecodeString(const std::vector<int>& encoded, - const void* config_buffer); - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_DECODER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder_test.cc b/tensorflow_text/core/kernels/sentencepiece/optimized_decoder_test.cc deleted file mode 100644 index 600941e76..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_decoder_test.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2026 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" - -#include <fstream> - -#include "file/base/path.h" -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/strings/str_format.h" -#include "src/sentencepiece.proto.h" -#include "src/sentencepiece_processor.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow_text/core/kernels/sentencepiece/model_converter.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -namespace internal { - -absl::Status TFReadFileToString(const std::string& filepath, - std::string* data) { - return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, - data); -} - -absl::Status StdReadFileToString(const std::string& filepath, - std::string* data) { - std::ifstream infile(filepath); - if (!infile.is_open()) { - return absl::NotFoundError( - absl::StrFormat("Error when opening %s", filepath)); - } - std::string contents((std::istreambuf_iterator<char>(infile)), - (std::istreambuf_iterator<char>())); - data->append(contents); - infile.close(); - return absl::OkStatus(); -} - -} // namespace internal - -namespace { - -static char kConfigFilePath[] = - "/tensorflow_text/python/ops/test_data/" - "fast_sentencepiece.model"; - -TEST(OptimizedEncoder, ConfigConverter) { - std::string config; - - auto status = internal::TFReadFileToString( - file::JoinPath(::testing::SrcDir(), kConfigFilePath), &config); - ASSERT_TRUE(status.ok()); - - ::sentencepiece::SentencePieceProcessor processor; - ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); - const auto converted_model = ConvertSentencepieceModelForDecoder(config); - const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); - ::sentencepiece::SentencePieceText reference_encoded; - ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); - - 
std::vector<int> encoded_vector; - encoded_vector.reserve(reference_encoded.pieces_size()); - for (const auto& piece : reference_encoded.pieces()) { - encoded_vector.push_back(piece.id()); - } - std::string ref_decoded; - ASSERT_TRUE(processor.Decode(encoded_vector, &ref_decoded).ok()); - const auto decoded = DecodeString(encoded_vector, converted_model.data()); - ASSERT_EQ(decoded.type, DecoderResultType::SUCCESS); - ASSERT_EQ(ref_decoded, decoded.decoded); -} -} // namespace - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc deleted file mode 100644 index 9602684d0..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.cc +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h" - -#include <algorithm> -#include <tuple> - -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie.h" -#include "tensorflow_text/core/kernels/sentencepiece/encoder_config_generated.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { -namespace { - -const char kSpaceSymbol[] = "\xe2\x96\x81"; - -template <typename processing_callback> -std::tuple<std::string, std::vector<int>> process_string( - const std::string& input, const std::vector<int>& offsets, - const processing_callback& pc) { - std::string result_string; - result_string.reserve(input.size()); - std::vector<int> result_offsets; - result_offsets.reserve(offsets.size()); - for (int i = 0, j = 0; i < input.size();) { - auto result = pc(input.data() + i, input.size() - i); - auto consumed = std::get<0>(result); - auto new_string = std::get<1>(result); - if (consumed == 0) { - // Skip the current byte and move forward. - result_string.push_back(input[i]); - result_offsets.push_back(offsets[j]); - i++; - j++; - continue; - } - result_string.append(new_string.data(), new_string.length()); - for (int i = 0; i < new_string.length(); ++i) { - result_offsets.push_back(offsets[j]); - } - j += consumed; - i += consumed; - } - return std::make_tuple(result_string, result_offsets); -} - -inline char is_whitespace(char c) { - return c == ' ' || c == '\t' || c == '\r' || c == '\n'; -} - -std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data, - int len) { - if (len == 0 || !is_whitespace(*data)) { - return std::make_tuple(0, utils::string_view(nullptr, 0)); - } - int num_consumed = 1; - for (; num_consumed < len && is_whitespace(data[num_consumed]); - ++num_consumed) { - } - return num_consumed > 1 - ? 
std::make_tuple(num_consumed, utils::string_view(" ", 1)) - : std::make_tuple(0, utils::string_view(nullptr, 0)); -} - -std::tuple<int, utils::string_view> find_replacement( - const char* data, int len, const DoubleArrayTrie& dat, - const flatbuffers::Vector<int8_t>& replacements) { - const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); - if (!max_match.empty()) { - // Because flatbuffer byte is signed char which is not the same as char, - // there is the reinterpret_cast here. - const char* replaced_string_ptr = - reinterpret_cast<const char*>(replacements.data() + max_match.id); - return std::make_tuple(max_match.match_length, - utils::string_view(replaced_string_ptr)); - } - return std::make_tuple(0, utils::string_view(nullptr, 0)); -} -} // namespace - -std::tuple<std::string, std::vector<int>> NormalizeString( - const std::string& in_string, const EncoderConfig& config) { - std::vector<int> output_offsets; - std::string result = in_string; - output_offsets.reserve(in_string.length()); - for (int i = 0; i < in_string.length(); ++i) { - output_offsets.push_back(i); - } - if (in_string.empty()) { - return std::make_tuple(result, output_offsets); - } - if (config.add_dummy_prefix()) { - result.insert(result.begin(), ' '); - output_offsets.insert(output_offsets.begin(), 0); - } - // Greedely replace normalized_prefixes with normalized_replacements - if (config.normalized_prefixes() != nullptr && - config.normalized_replacements() != nullptr) { - const DoubleArrayTrie normalized_prefixes_matcher( - config.normalized_prefixes()->nodes()); - const auto norm_replace = [&config, &normalized_prefixes_matcher]( - const char* data, int len) { - return find_replacement(data, len, normalized_prefixes_matcher, - *config.normalized_replacements()); - }; - std::tie(result, output_offsets) = - process_string(result, output_offsets, norm_replace); - } - if (config.remove_extra_whitespaces()) { - std::tie(result, output_offsets) = - process_string(result, 
output_offsets, remove_extra_whitespaces); - if (!result.empty() && is_whitespace(result.back())) { - result.pop_back(); - output_offsets.pop_back(); - } - } - if (config.escape_whitespaces()) { - const auto replace_whitespaces = [](const char* data, int len) { - if (len > 0 && is_whitespace(*data)) { - return std::make_tuple(1, utils::string_view(kSpaceSymbol)); - } - return std::make_tuple(0, utils::string_view(nullptr, 0)); - }; - std::tie(result, output_offsets) = - process_string(result, output_offsets, replace_whitespaces); - } - - return std::make_tuple(result, output_offsets); -} - -EncoderResult EncodeNormalizedString(const std::string& str, - const std::vector<int>& offsets, - const EncoderConfig& config, bool add_bos, - bool add_eos, bool reverse) { - const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); - const flatbuffers::Vector<float>* piece_scores = config.pieces_scores(); - const int unknown_code = config.unknown_code(); - const float unknown_penalty = config.unknown_penalty(); - struct LatticeElement { - float score = 0; - int code = -1; - int prev_position = -1; - LatticeElement(float score_, int code_, int prev_position_) - : score(score_), code(code_), prev_position(prev_position_) {} - LatticeElement() {} - }; - const int length = str.length(); - std::vector<LatticeElement> lattice(length + 1); - for (int i = 0; i < length; ++i) { - if (i > 0 && lattice[i].prev_position < 0) { - // This state is unreachable. - continue; - } - if (unknown_code >= 0) { - // Put unknown code. - const float penalized_score = lattice[i].score + unknown_penalty; - const int pos = i + 1; - LatticeElement& current_element = lattice[pos]; - if (current_element.prev_position < 0 || - current_element.score < penalized_score) { - current_element = LatticeElement( - penalized_score, unknown_code, - // If the current state is already reached by unknown code, merge - // states. - lattice[i].code == unknown_code ? 
lattice[i].prev_position : i); - } - } - auto lattice_update = [&lattice, i, - piece_scores](const DoubleArrayTrie::Match& m) { - LatticeElement& target_element = lattice[i + m.match_length]; - const float score = lattice[i].score + (*piece_scores)[m.id]; - if (target_element.prev_position < 0 || target_element.score < score) { - target_element = LatticeElement(score, m.id, i); - } - }; - piece_matcher.IteratePrefixMatches( - utils::string_view(str.data() + i, length - i), lattice_update); - } - - EncoderResult result; - if (add_eos) { - result.codes.push_back(config.end_code()); - result.offsets.push_back(length); - } - if (lattice[length].prev_position >= 0) { - for (int pos = length; pos > 0;) { - auto code = lattice[pos].code; - if (code != config.unknown_code()) { - code += config.encoding_offset(); - } - result.codes.push_back(code); - pos = lattice[pos].prev_position; - result.offsets.push_back(offsets[pos]); - } - } - if (add_bos) { - result.codes.push_back(config.start_code()); - result.offsets.push_back(0); - } - if (!reverse) { - std::reverse(result.codes.begin(), result.codes.end()); - std::reverse(result.offsets.begin(), result.offsets.end()); - } - return result; -} - -EncoderResult EncodeString(const std::string& string, const void* config_buffer, - bool add_bos, bool add_eos, bool reverse) { - // Get the config from the buffer. 
- const EncoderConfig* config = GetEncoderConfig(config_buffer); - if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { - EncoderResult result; - result.type = EncoderResultType::WRONG_CONFIG; - return result; - } - std::string normalized_string; - std::vector<int> offsets; - std::tie(normalized_string, offsets) = NormalizeString(string, *config); - return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, - add_eos, reverse); -} - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h index d1ca949a6..324219aa7 100644 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h +++ b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h @@ -12,53 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/optimized_encoder.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ - -// Sentencepiece encoder optimized with memmapped model. - -#include <string> -#include <tuple> -#include <vector> - -#include "tensorflow_text/core/kernels/sentencepiece/encoder_config_generated.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 }; - -struct EncoderResult { - EncoderResultType type = EncoderResultType::SUCCESS; - std::vector<int> codes; - std::vector<int> offsets; -}; -std::tuple<std::string, std::vector<int>> NormalizeString( - const std::string& in_string, const EncoderConfig& config); - -// Encodes one string and returns ids and offsets. Takes the configuration as a -// type-erased buffer. -EncoderResult EncodeString(const std::string& string, const void* config_buffer, - bool add_bos, bool add_eos, bool reverse); - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder_test.cc b/tensorflow_text/core/kernels/sentencepiece/optimized_encoder_test.cc deleted file mode 100644 index ecab756f4..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/optimized_encoder_test.cc +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h" - -#include <fstream> - -#include "file/base/path.h" -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "src/sentencepiece.proto.h" -#include "src/sentencepiece_processor.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow_text/core/kernels/sentencepiece/double_array_trie_builder.h" -#include "tensorflow_text/core/kernels/sentencepiece/encoder_config_generated.h" -#include "tensorflow_text/core/kernels/sentencepiece/model_converter.h" - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -namespace internal { - -absl::Status TFReadFileToString(const std::string& filepath, - std::string* data) { - return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, - data); -} - -absl::Status StdReadFileToString(const std::string& filepath, - std::string* data) { - std::ifstream infile(filepath); - if (!infile.is_open()) { - return absl::NotFoundError( - absl::StrFormat("Error when opening %s", filepath)); - } - std::string contents((std::istreambuf_iterator<char>(infile)), - (std::istreambuf_iterator<char>())); - data->append(contents); - infile.close(); - return absl::OkStatus(); -} -} // namespace internal - -namespace { - -static char kConfigFilePath[] = - "/tensorflow_text/python/ops/test_data/" - "fast_sentencepiece.model"; - -TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { - flatbuffers::FlatBufferBuilder builder(1024); - EncoderConfigBuilder ecb(builder); - ecb.add_remove_extra_whitespaces(true); - ecb.add_add_dummy_prefix(true); - ecb.add_escape_whitespaces(true); - FinishEncoderConfigBuffer(builder, ecb.Finish()); - const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); - { - const auto result = 
NormalizeString("x y", *config); - const auto res_string = std::get<0>(result); - const auto offsets = std::get<1>(result); - EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); - EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); - } - { - const auto result = NormalizeString("\tx y\n", *config); - const auto res_string = std::get<0>(result); - const auto offsets = std::get<1>(result); - EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); - EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); - } -} - -TEST(OptimizedEncoder, NormalizeStringReplacement) { - flatbuffers::FlatBufferBuilder builder(1024); - const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"}; - const char norm_replacements[] = "A1\0A2\0A3\0A4"; - const auto trie_vector = - builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); - const auto norm_r = builder.CreateVector<int8_t>( - reinterpret_cast<const signed char*>(norm_replacements), - sizeof(norm_replacements)); - TrieBuilder trie_builder(builder); - trie_builder.add_nodes(trie_vector); - const auto norm_p = trie_builder.Finish(); - EncoderConfigBuilder ecb(builder); - ecb.add_remove_extra_whitespaces(false); - ecb.add_normalized_prefixes(norm_p); - ecb.add_normalized_replacements(norm_r); - FinishEncoderConfigBuffer(builder, ecb.Finish()); - const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); - { - const auto result = NormalizeString("ABAABAAABAAAA", *config); - const auto res_string = std::get<0>(result); - const auto offsets = std::get<1>(result); - EXPECT_EQ(res_string, "A1BA2BA3BA4"); - EXPECT_THAT(offsets, - ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); - } -} - -TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { - flatbuffers::FlatBufferBuilder builder(1024); - const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA", - "X"}; - const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; - const auto trie_vector = - 
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); - const auto norm_r = builder.CreateVector<int8_t>( - reinterpret_cast<const signed char*>(norm_replacements), - sizeof(norm_replacements)); - TrieBuilder trie_builder(builder); - trie_builder.add_nodes(trie_vector); - const auto norm_p = trie_builder.Finish(); - EncoderConfigBuilder ecb(builder); - ecb.add_remove_extra_whitespaces(true); - ecb.add_normalized_prefixes(norm_p); - ecb.add_normalized_replacements(norm_r); - FinishEncoderConfigBuffer(builder, ecb.Finish()); - const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); - { - const auto result = NormalizeString("XXABAABAAABAAAA", *config); - const auto res_string = std::get<0>(result); - const auto offsets = std::get<1>(result); - EXPECT_EQ(res_string, " A1BA2BA3BA4"); - EXPECT_THAT(offsets, - ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); - } -} - -TEST(OptimizedEncoder, ConfigConverter) { - std::string config; - auto status = internal::TFReadFileToString( - file::JoinPath(::testing::SrcDir(), kConfigFilePath), &config); - ASSERT_TRUE(status.ok()); - - ::sentencepiece::SentencePieceProcessor processor; - ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); - const auto converted_model = ConvertSentencepieceModel(config); - const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); - const auto encoded = - EncodeString(test_string, converted_model.data(), false, false, false); - ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); - - ::sentencepiece::SentencePieceText reference_encoded; - ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); - EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); - for (int i = 0; i < encoded.codes.size(); ++i) { - EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); - EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); - } -} - -} // namespace -} // namespace sentencepiece -} // namespace text -} // 
namespace tensorflow diff --git a/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.cc b/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.cc deleted file mode 100644 index e5ae73622..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h" - -namespace tflite { -namespace ops { -namespace custom { -TfLiteRegistration* Register_FAST_SENTENCEPIECE_TOKENIZER(); -TfLiteRegistration* Register_FAST_SENTENCEPIECE_DETOKENIZER(); - -namespace text { - -extern "C" void AddFastSentencepieceTokenize( - tflite::MutableOpResolver* resolver) { - resolver->AddCustom( - "TFText>FastSentencepieceTokenize", - ::tflite::ops::custom::Register_FAST_SENTENCEPIECE_TOKENIZER()); -} - -extern "C" void AddFastSentencepieceDetokenize( - tflite::MutableOpResolver* resolver) { - resolver->AddCustom( - "TFText>FastSentencepieceDetokenize", - ::tflite::ops::custom::Register_FAST_SENTENCEPIECE_DETOKENIZER()); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h b/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h index 6a64a6b7a..4a0eb6c7e 100644 --- a/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h +++ b/tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h @@ -12,40 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/py_tflite_registerer.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -// C-function that is called from the Python Wrapper. -extern "C" void AddFastSentencepieceTokenize( - tflite::MutableOpResolver *resolver); - -extern "C" void AddFastSentencepieceDetokenize( - tflite::MutableOpResolver *resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_constants.h b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_constants.h index f0d95d1c8..cf3917f2c 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_constants.h +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_constants.h @@ -12,44 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/sentencepiece_constants.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -// The constant is copied from -// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc -constexpr float kUnkPenalty = 10.0; - -// These constants are copied from -// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc -// -// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK). -constexpr char kSpaceSymbol[] = "\xe2\x96\x81"; - -// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK), -// since this character can be useful both for user and -// developer. We can easily figure out that <unk> is emitted. 
-constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 "; - -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h index 24b41fc8c..898ba8b3e 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h @@ -12,34 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/sentencepiece_detokenizer.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ - -// Constants are shared between TF and TFLite SentencepieceTokenizer kernels. 
-namespace tensorflow { -namespace text { -constexpr int kSPModelIndex = 0; -constexpr int kInputIndex = 1; -constexpr int kInputSplits = 2; -constexpr int kAddBOSInput = 4; -constexpr int kAddEOSInput = 5; -constexpr int kReverseInput = 6; -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc deleted file mode 100644 index 3338ca38e..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" -#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h" - -namespace tensorflow { -namespace text { - -template <typename Tsplits> -class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel { - public: - explicit TFSentencepieceDetokenizerOp(tensorflow::OpKernelConstruction* ctx) - : OpKernel(ctx) {} - void Compute(tensorflow::OpKernelContext* ctx) override { - const auto& model_tensor = ctx->input(kSPModelIndex); - const auto& input_values_tensor = ctx->input(kInputIndex); - const auto input_values_flat = - input_values_tensor.flat<tensorflow::int32>(); - const auto& input_splits_tensor = ctx->input(kInputSplits); - const auto input_splits_flat = input_splits_tensor.flat<Tsplits>(); - const int num_of_sentences = input_splits_flat.size() - 1; - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, {num_of_sentences}, &output_tensor)); - auto output_flat = output_tensor->flat<tensorflow::tstring>(); - std::vector<int> codes_for_split; - int input_offset = 0; - for (int i = 0; i < num_of_sentences; i++) { - // Create a vector of int32 from input according to spans. 
- const int split_size = input_splits_flat(i + 1) - input_splits_flat(i); - codes_for_split.clear(); - codes_for_split.reserve(split_size); - for (int j = 0; j < split_size; ++j) { - codes_for_split.push_back(input_values_flat(input_offset++)); - } - const auto res = sentencepiece::DecodeString( - codes_for_split, model_tensor.data()); - OP_REQUIRES(ctx, res.type == sentencepiece::DecoderResultType::SUCCESS, - absl::Status(static_cast<absl::StatusCode>( - absl::StatusCode::kInternal), - "Sentencepiece conversion failed")); - output_flat(i) = res.decoded; - } - } -}; -} // namespace text -} // namespace tensorflow - -REGISTER_KERNEL_BUILDER( - Name("TFText>FastSentencepieceDetokenize") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint<tensorflow::int32>("Tsplits"), - tensorflow::text::TFSentencepieceDetokenizerOp<tensorflow::int32>); -REGISTER_KERNEL_BUILDER( - Name("TFText>FastSentencepieceDetokenize") - .Device(tensorflow::DEVICE_CPU) - .TypeConstraint<tensorflow::int64>("Tsplits"), - tensorflow::text::TFSentencepieceDetokenizerOp<tensorflow::int64>); diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc deleted file mode 100644 index 89e8a4723..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_tflite.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -/** - * Sentencepiece tflite detokenizer implementation. - */ -#include <algorithm> -#include <iterator> - -#include "flatbuffers/flexbuffers.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/context.h" -#include "tensorflow/lite/kernels/internal/tensor.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" -#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -namespace sentencepiece { -namespace detokenizer { - -constexpr int kOutputValuesInd = 0; -// Initializes text encoder object from serialized parameters. -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, - size_t /*length*/) { - return nullptr; -} -void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // TODO(mgubin): Add checks for input and output tensors. 
- TfLiteTensor& output_values = - context->tensors[node->outputs->data[kOutputValuesInd]]; - SetTensorToDynamic(&output_values); - // TODO(mgubin): Check input types. - - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor& model_tensor = - context->tensors[node->inputs->data[tensorflow::text::kSPModelIndex]]; - const auto model_buffer_data = model_tensor.data.data; - const TfLiteTensor& input_encoded = - context->tensors[node->inputs->data[tensorflow::text::kInputIndex]]; - const int32_t* input_encoded_data = input_encoded.data.i32; - const TfLiteTensor& input_splits = - context->tensors[node->inputs->data[tensorflow::text::kInputSplits]]; - const int num_of_sentences = NumElements(input_splits.dims) - 1; - const int32_t* input_splits_data = input_splits.data.i32; - - DynamicBuffer buf; - - std::vector<int> codes_for_split; - int input_offset = 0; - for (int i = 0; i < num_of_sentences; i++) { - // Create a vector of int32 from input according to spans. 
- const int split_size = input_splits_data[i + 1] - input_splits_data[i]; - codes_for_split.clear(); - std::copy(input_encoded_data + input_offset, - input_encoded_data + input_offset + split_size, - std::back_inserter(codes_for_split)); - const auto res = tensorflow::text::sentencepiece::DecodeString( - codes_for_split, model_buffer_data); - TF_LITE_ENSURE_MSG( - context, - res.type == tensorflow::text::sentencepiece::DecoderResultType::SUCCESS, - "Sentencepiece decoding failed"); - buf.AddString(res.decoded.data(), res.decoded.length()); - input_offset += split_size; - } - TfLiteTensor& output_values = - context->tensors[node->outputs->data[kOutputValuesInd]]; - buf.WriteToTensor(&output_values, nullptr); - return kTfLiteOk; -} -} // namespace detokenizer -} // namespace sentencepiece -} // namespace text - -TfLiteRegistration* Register_FAST_SENTENCEPIECE_DETOKENIZER() { - static TfLiteRegistration r = { - text::sentencepiece::detokenizer::Initialize, - text::sentencepiece::detokenizer::Free, - text::sentencepiece::detokenizer::Prepare, - text::sentencepiece::detokenizer::Eval}; - return &r; -} - -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h index 8ce03f69b..423fda05f 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h @@ -12,34 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/sentencepiece_tokenizer.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ - -// Constants are shared between TF and TFLite SentencepieceTokenizer kernels. -namespace tensorflow { -namespace text { - -constexpr int kSPModelIndex = 0; -constexpr int kInputIndex = 1; -constexpr int kAddBOSInput = 4; -constexpr int kAddEOSInput = 5; -constexpr int kReverseInput = 6; -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_kernel.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_kernel.cc deleted file mode 100644 index 22a5beaf6..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_kernel.cc +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include <cstdint> -#include <iterator> -#include <limits> -#include <vector> - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h" -#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h" - -namespace tensorflow { -namespace text{ - -class TFSentencepieceOp : public tensorflow::OpKernel { - public: - explicit TFSentencepieceOp(tensorflow::OpKernelConstruction* ctx) - : OpKernel(ctx) {} - void Compute(tensorflow::OpKernelContext* ctx) override { - const auto& model_tensor = ctx->input(kSPModelIndex); - const auto& input_values_tensor = ctx->input(kInputIndex); - const auto input_values_flat = - input_values_tensor.flat<tensorflow::tstring>(); - const int64_t num_of_input_values = input_values_flat.size(); - - const auto& add_bos_tensor = ctx->input(kAddBOSInput); - const bool add_bos = add_bos_tensor.scalar<bool>()(); - const auto& add_eos_tensor = ctx->input(kAddEOSInput); - const bool add_eos = add_eos_tensor.scalar<bool>()(); - const auto& reverse_tensor = ctx->input(kReverseInput); - const bool reverse = reverse_tensor.scalar<bool>()(); - - std::vector<int32> encoded; - std::vector<int32> splits; - for (int i = 0; i < num_of_input_values; ++i) { - const auto res = sentencepiece::EncodeString( - input_values_flat(i), model_tensor.data(), add_bos, add_eos, reverse); - OP_REQUIRES(ctx, res.type == sentencepiece::EncoderResultType::SUCCESS, - absl::Status(static_cast<absl::StatusCode>( - absl::StatusCode::kInternal), - "Sentencepiece conversion failed")); - std::copy(res.codes.begin(), res.codes.end(), - std::back_inserter(encoded)); - splits.emplace_back(encoded.size()); - } - tensorflow::Tensor* 
output_values_tensor = nullptr; - tensorflow::Tensor* output_splits_tensor = nullptr; - OP_REQUIRES(ctx, encoded.size() < std::numeric_limits<int32_t>::max(), - errors::InvalidArgument( - "Encoded input must contain less than 2^31 characters.")); - OP_REQUIRES( - ctx, splits.size() + 1 < std::numeric_limits<int32_t>::max(), - errors::InvalidArgument("Splits tensor is limited to 2^31-1 values.")); - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, {static_cast<int32_t>(encoded.size())}, - &output_values_tensor)); - OP_REQUIRES_OK( - ctx, ctx->allocate_output(1, {static_cast<int32_t>(splits.size()) + 1}, - &output_splits_tensor)); - - auto values_tensor_flat = output_values_tensor->vec<int32>(); - auto splits_tensor_flat = output_splits_tensor->vec<int32>(); - for (int32_t i = 0; i < encoded.size(); ++i) { - values_tensor_flat(i) = encoded[i]; - } - splits_tensor_flat(0) = 0; - for (int32_t i = 0; i < splits.size(); ++i) { - splits_tensor_flat(i + 1) = splits[i]; - } - } -}; - -} // namespace text -} // namespace tensorflow -REGISTER_KERNEL_BUILDER( - Name("TFText>FastSentencepieceTokenize").Device(tensorflow::DEVICE_CPU), - tensorflow::text::TFSentencepieceOp); diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_tflite.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_tflite.cc deleted file mode 100644 index ddce5bd48..000000000 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_tflite.cc +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * Sentencepiece tflite tokenizer implementation. 
- */ -#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h" -#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h" -#include "flatbuffers/flexbuffers.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/context.h" -#include "tensorflow/lite/kernels/internal/tensor.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/model.h" -#include "tensorflow/lite/string_util.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { -namespace sentencepiece { -namespace tokenizer { - -constexpr int kOutputValuesInd = 0; -constexpr int kOutputSplitsInd = 1; - -namespace { -TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) { - TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); - int index = 0; - for (const int size : sizes) { - array_size->data[index++] = size; - } - return array_size; -} -} // namespace - -// Initializes text encoder object from serialized parameters. -void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, - size_t /*length*/) { - return nullptr; -} -void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} - -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // TODO(mgubin): Add checks for input and output tensors. 
- TfLiteTensor& output_values = - context->tensors[node->outputs->data[kOutputValuesInd]]; - SetTensorToDynamic(&output_values); - - TfLiteTensor& output_splits = - context->tensors[node->outputs->data[kOutputSplitsInd]]; - SetTensorToDynamic(&output_splits); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor& model_tensor = - context->tensors[node->inputs->data[tensorflow::text::kSPModelIndex]]; - const auto model_buffer_data = model_tensor.data.data; - const TfLiteTensor& input_text = - context->tensors[node->inputs->data[tensorflow::text::kInputIndex]]; - - const TfLiteTensor add_bos_tensor = - context->tensors[node->inputs->data[tensorflow::text::kAddBOSInput]]; - const bool add_bos = add_bos_tensor.data.b[0]; - const TfLiteTensor add_eos_tensor = - context->tensors[node->inputs->data[tensorflow::text::kAddEOSInput]]; - const bool add_eos = add_eos_tensor.data.b[0]; - const TfLiteTensor reverse_tensor = - context->tensors[node->inputs->data[tensorflow::text::kReverseInput]]; - const bool reverse = reverse_tensor.data.b[0]; - - std::vector<int32> encoded; - std::vector<int32> splits; - const int num_strings = tflite::GetStringCount(&input_text); - for (int i = 0; i < num_strings; ++i) { - const auto strref = tflite::GetString(&input_text, i); - const auto res = tensorflow::text::sentencepiece::EncodeString( - std::string(strref.str, strref.len), model_buffer_data, add_bos, - add_eos, reverse); - TF_LITE_ENSURE_MSG( - context, - res.type == tensorflow::text::sentencepiece::EncoderResultType::SUCCESS, - "Sentencepiece conversion failed"); - std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); - splits.emplace_back(encoded.size()); - } - - TfLiteTensor& output_values = - context->tensors[node->outputs->data[kOutputValuesInd]]; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor( - context, &output_values, - CreateSizeArray({static_cast<int>(encoded.size())}))); - int32_t* 
output_values_flat = output_values.data.i32; - std::copy(encoded.begin(), encoded.end(), output_values_flat); - TfLiteTensor& output_splits = - context->tensors[node->outputs->data[kOutputSplitsInd]]; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor( - context, &output_splits, - CreateSizeArray({static_cast<int>(splits.size() + 1)}))); - int32_t* output_splits_flat = output_splits.data.i32; - *output_splits_flat = 0; - std::copy(splits.begin(), splits.end(), output_splits_flat + 1); - return kTfLiteOk; -} -} // namespace tokenizer -} // namespace sentencepiece -} // namespace text - -TfLiteRegistration* Register_FAST_SENTENCEPIECE_TOKENIZER() { - static TfLiteRegistration r = { - text::sentencepiece::tokenizer::Initialize, - text::sentencepiece::tokenizer::Free, - text::sentencepiece::tokenizer::Prepare, - text::sentencepiece::tokenizer::Eval}; - return &r; -} - -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/sentencepiece/utils.h b/tensorflow_text/core/kernels/sentencepiece/utils.h index fb9d850e0..e6d84c924 100644 --- a/tensorflow_text/core/kernels/sentencepiece/utils.h +++ b/tensorflow_text/core/kernels/sentencepiece/utils.h @@ -12,67 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_UTILS_H_ +#define TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_UTILS_H_ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at +#include "tensorflow/core/kernels/text/sentencepiece/utils.h" - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_ -#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_ - -#include <ostream> -#include <string> - -namespace tensorflow { -namespace text { -namespace sentencepiece { - -// AOSP and WASM doesn't support string_view, -// we put here a minimal re-implementation. -namespace utils { - -class string_view { - public: - explicit string_view(const std::string& s) - : str_(s.data()), len_(s.length()) {} - string_view(const char* str, int len) : str_(str), len_(len) {} - // A constructor from c string. 
- explicit string_view(const char* s) : str_(s), len_(strlen(s)) {} - - int length() const { return len_; } - const char* data() const { return str_; } - bool empty() const { return len_ == 0; } - unsigned char at(int i) const { return str_[i]; } - - private: - const char* str_ = nullptr; - const int len_ = 0; -}; - -inline std::ostream& operator<<(std::ostream& os, const string_view& sv) { - os << std::string(sv.data(), sv.length()); - return os; -} -inline bool operator==(const string_view& view1, const string_view& view2) { - if (view1.length() != view2.length()) { - return false; - } - return memcmp(view1.data(), view2.data(), view1.length()) == 0; -} - -} // namespace utils -} // namespace sentencepiece -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_ +#endif // TENSORFLOW_TEXT_CORE_KERNELS_SENTENCEPIECE_UTILS_H_ diff --git a/tensorflow_text/core/kernels/sentencepiece_kernels.cc b/tensorflow_text/core/kernels/sentencepiece_kernels.cc deleted file mode 100644 index a1f57bc19..000000000 --- a/tensorflow_text/core/kernels/sentencepiece_kernels.cc +++ /dev/null @@ -1,736 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "absl/base/attributes.h" -#include "absl/base/optimization.h" -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/meta/type_traits.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "src/sentencepiece_model.pb.h" -#include "src/sentencepiece.pb.h" -#include "src/sentencepiece_processor.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/work_sharder.h" - -namespace tensorflow { -namespace text { - -namespace { - -// Our resource object that will hold the SentencePiece processor. 
-struct SentencepieceResource : public ResourceBase { - sentencepiece::SentencePieceProcessor processor; - int64 memory_used; - bool add_bos = false; - bool add_eos = false; - bool reverse = false; - mutable absl::Mutex mu; - - string DebugString() const override { return "Sentencepiece Resource"; } - - int64 MemoryUsed() const override { return memory_used; } - - bool SameOptions(bool add_bos, bool add_eos, bool reverse) const { - return (add_bos == this->add_bos) && (add_eos == this->add_eos) && - (reverse == this->reverse); - } - - Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { - absl::ReaderMutexLock l(&mu); - // We set use_node_name_sharing with a unique node name so that the resource - // can outlive the kernel. This means that the lifetime of the re-created - // resource will be tied to the lifetime of the resource manager it is - // created in. - static std::atomic<int64> counter(0); - std::string unique_node_name = strings::StrCat( - "SentencepieceResourceFromGraphDef", "/", counter.fetch_add(1)); - std::string model = processor.model_proto().SerializeAsString(); - *out = ops::SourceOp( - "SentencepieceOp", - builder->opts() - .WithName(unique_node_name) - .WithAttr("model", model) - .WithAttr("use_node_name_sharing", true)); - return absl::OkStatus(); - } -}; - -// According to .../tensorflow/core/util/work_sharder.cc, this values determines -// how much to shard. It assumes each cost unit is 1ns, and the minimum cost -// per shard is 10000 (10us). 
-// TODO(broken) Determine a medium cost of a call to the SentencePiece processor -constexpr int64 kCostPerUnit = 10000; - -::tensorflow::Status ToTFStatus(const sentencepiece::util::Status& s) { - if (s.ok()) return ::tensorflow::Status(); - return ::tensorflow::Status(static_cast<::absl::StatusCode>(s.code()), - ::tensorflow::string(s.message())); -} - -template <typename T> -T GetPieceOrId(const sentencepiece::SentencePieceText::SentencePiece& sp); - -template <> -tensorflow::tstring GetPieceOrId<tensorflow::tstring>( - const sentencepiece::SentencePieceText::SentencePiece& sp) { - return sp.piece(); -} - -template <> -int32 GetPieceOrId<int32>( - const sentencepiece::SentencePieceText::SentencePiece& sp) { - return sp.id(); -} - -tensorflow::Status HandleExtraOptions(OpKernelContext* ctx, - SentencepieceResource* sp) { - const Tensor* add_bos_tensor = nullptr; - TF_RETURN_IF_ERROR(ctx->input("add_bos", &add_bos_tensor)); - const bool add_bos = add_bos_tensor->scalar<bool>()(); - - const Tensor* add_eos_tensor = nullptr; - TF_RETURN_IF_ERROR(ctx->input("add_eos", &add_eos_tensor)); - const bool add_eos = add_eos_tensor->scalar<bool>()(); - - const Tensor* reverse_tensor = nullptr; - TF_RETURN_IF_ERROR(ctx->input("reverse", &reverse_tensor)); - const bool reverse = reverse_tensor->scalar<bool>()(); - - { - // Because we expect most of the time no change in these options, we grab - // the reader lock once and do a quick check first. 
- absl::ReaderMutexLock l(&sp->mu); - if (sp->SameOptions(add_bos, add_eos, reverse)) { - return absl::OkStatus(); - } - } - - absl::WriterMutexLock lock(&sp->mu); - if (sp->SameOptions(add_bos, add_eos, reverse)) { - return absl::OkStatus(); - } - string options; - sp->add_bos = add_bos; - if (sp->add_bos) { - absl::StrAppend(&options, "bos"); - } - sp->add_eos = add_eos; - if (sp->add_eos) { - if (!options.empty()) { - absl::StrAppend(&options, ":"); - } - absl::StrAppend(&options, "eos"); - } - sp->reverse = reverse; - if (sp->reverse) { - if (!options.empty()) { - absl::StrAppend(&options, ":"); - } - absl::StrAppend(&options, "reverse"); - } - - TF_RETURN_IF_ERROR(ToTFStatus(sp->processor.SetEncodeExtraOptions(options))); - TF_RETURN_IF_ERROR(ToTFStatus(sp->processor.SetDecodeExtraOptions(options))); - - return absl::OkStatus(); -} - -} // namespace - -class SentencepieceOp : public OpKernel { - public: - explicit SentencepieceOp(OpKernelConstruction* ctx) - : OpKernel(ctx), sp_set_(false) { - OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING, - tensorflow::TensorShape({2}), &sp_)); - OP_REQUIRES_OK( - ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_)); - } - - ~SentencepieceOp() override { - // If the table object was not shared, delete it. - if (sp_set_ && cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete<SentencepieceResource>(cinfo_.container(), - cinfo_.name()) - .ok()) { - // Do nothing; the resource may have been deleted by session resets. 
- } - } - } - - void Compute(OpKernelContext* ctx) override { - absl::MutexLock lock(&mu_); - - if (!sp_set_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), - use_node_name_sharing_)); - } - - auto creator = - [ctx, this](SentencepieceResource** resource) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - SentencepieceResource* sp = new SentencepieceResource(); - - string model_proto_attr; - TF_RETURN_IF_ERROR( - GetNodeAttr(this->def(), "model", &model_proto_attr)); - - if (TF_PREDICT_FALSE(model_proto_attr.empty())) { - return Status(tensorflow::errors::InvalidArgument( - "Model argument must be specified.")); - } - // Loads serialized sentencepiece model proto to enable embedding - // the relatively small sentencepiece model proto into the - // tensorflow graph such that the tensorflow graph is - // self-contained. - TF_RETURN_IF_ERROR(ToTFStatus( - sp->processor.LoadFromSerializedProto(model_proto_attr))); - // TODO(broken): Determine a better computation of what the memory - // requirements for the processor are. - sp->memory_used = model_proto_attr.size(); - - if (ctx->track_allocations()) { - ctx->record_persistent_memory_allocation(sp->MemoryUsed()); - } - - *resource = sp; - return absl::OkStatus(); - }; - - // Register the ResourceType alias. - SentencepieceResource* resource = nullptr; - OP_REQUIRES_OK( - ctx, cinfo_.resource_manager() - ->template LookupOrCreate<SentencepieceResource>( - cinfo_.container(), cinfo_.name(), &resource, creator)); - core::ScopedUnref unref_me(resource); - - // Put a handle to resource in the output tensor (the other aliases will - // have the same handle). 
- Tensor* handle; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); - handle->scalar<ResourceHandle>()() = - MakeResourceHandle<SentencepieceResource>(ctx, cinfo_.container(), - cinfo_.name()); - sp_set_ = true; - } - - private: - absl::Mutex mu_; - Tensor sp_ ABSL_GUARDED_BY(mu_); - bool sp_set_ ABSL_GUARDED_BY(mu_); - ContainerInfo cinfo_; - bool use_node_name_sharing_; - TF_DISALLOW_COPY_AND_ASSIGN(SentencepieceOp); -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceOp").Device(DEVICE_CPU), - tensorflow::text::SentencepieceOp); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceOp"); - -template <typename T, typename Tsplits> -class SentencepieceTokenizeOp : public OpKernel { - public: - explicit SentencepieceTokenizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - ctx->GetAttr("return_nbest", &return_nbest_).IgnoreError(); - } - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - const Tensor& input_values_tensor = ctx->input(1); - const auto input_values_flat = - input_values_tensor.flat<tensorflow::tstring>(); - const int64 num_of_input_values = input_values_flat.size(); - - const Tensor* nbest_size_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("nbest_size", &nbest_size_tensor)); - const Tensor* alpha_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("alpha", &alpha_tensor)); - - OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp)); - - if (return_nbest_) { - OP_REQUIRES(ctx, nbest_size_tensor->dims() == 0, - errors::InvalidArgument( - "When return_nbest is true nbest_size must " - "be a scalar; got", - nbest_size_tensor->shape().DebugString(), "instead")); - OP_REQUIRES(ctx, 
nbest_size_tensor->scalar<int32>()() >= 1, - errors::InvalidArgument( - "When return_nbest is true nbest_size must be >= 1; got ", - nbest_size_tensor->scalar<int32>()())); - } - - std::vector<std::vector<typename std::conditional< - std::is_same<T, tstring>::value, std::string, T>::type>> - tokens(return_nbest_ ? 0 : num_of_input_values); - std::vector<std::vector<std::vector<typename std::conditional< - std::is_same<T, tstring>::value, std::string, T>::type>>> - nbest_tokens(return_nbest_ ? num_of_input_values : 0); - if (num_of_input_values > 0) { - const bool return_nbest = return_nbest_; - const auto& worker_threads = - *(ctx->device()->tensorflow_cpu_worker_threads()); - ::tensorflow::Shard( - worker_threads.num_threads, // max parallelism - worker_threads.workers, // thread pool - num_of_input_values, // total number of data to process. - kCostPerUnit, // cost per unit - [ctx, sp, &input_values_flat, &tokens, &nbest_tokens, - &nbest_size_tensor, &alpha_tensor, - return_nbest](int64 start, int64 limit) { - absl::ReaderMutexLock lock(&sp->mu); - for (int i = start; i < limit; ++i) { - const int32 nbest_size = nbest_size_tensor->dims() == 1 - ? nbest_size_tensor->vec<int32>()(i) - : nbest_size_tensor->scalar<int32>()(); - if (return_nbest) { - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.NBestEncode( - input_values_flat(i), nbest_size, - &nbest_tokens[i]))); - } else if (nbest_size == 0 || nbest_size == 1) { - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode( - input_values_flat(i), &tokens[i]))); - } else { - const float alpha = alpha_tensor->dims() == 1 - ? 
alpha_tensor->vec<float>()(i) - : alpha_tensor->scalar<float>()(); - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.SampleEncode( - input_values_flat(i), nbest_size, alpha, - &tokens[i]))); - } - } - }); - } - - if (return_nbest_) { - for (auto& col : nbest_tokens) { - for (auto& row : col) { - tokens.push_back(std::move(row)); - } - } - nbest_tokens.clear(); - } - int64 total_tokens = 0; - for (auto& tokens_row : tokens) { - total_tokens += tokens_row.size(); - } - - Tensor* output_values_tensor = nullptr; - Tensor* output_splits_tensor = nullptr; - - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, {total_tokens}, &output_values_tensor)); - int64 splits_size = tokens.size() + 1; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(1, {splits_size}, &output_splits_tensor)); - - auto values_tensor_flat = output_values_tensor->vec<T>(); - auto splits_tensor_flat = output_splits_tensor->vec<Tsplits>(); - - int i = 0; - splits_tensor_flat(0) = 0; - for (int row = 0; row < tokens.size(); ++row) { - for (int col = 0; col < tokens[row].size(); ++col, ++i) { - values_tensor_flat(i) = tokens[row][col]; - } - splits_tensor_flat(row + 1) = i; - } - } - - bool return_nbest_{false}; -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type") - .TypeConstraint<int32>("Tsplits"), - SentencepieceTokenizeOp<int32, int32>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<tensorflow::tstring>("out_type") - .TypeConstraint<int32>("Tsplits"), - SentencepieceTokenizeOp<tensorflow::tstring, int32>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type") - .TypeConstraint<int64>("Tsplits"), - SentencepieceTokenizeOp<int32, int64>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<tensorflow::tstring>("out_type") - .TypeConstraint<int64>("Tsplits"), - 
SentencepieceTokenizeOp<tensorflow::tstring, int64>); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceTokenizeOp"); - -template <typename T, typename Tsplits> -class SentencepieceTokenizeWithOffsetsOp : public OpKernel { - public: - explicit SentencepieceTokenizeWithOffsetsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { - ctx->GetAttr("return_nbest", &return_nbest_).IgnoreError(); - } - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - const Tensor& input_values_tensor = ctx->input(1); - const auto input_values_flat = - input_values_tensor.flat<tensorflow::tstring>(); - const int64 num_of_input_values = input_values_flat.size(); - - const Tensor* nbest_size_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("nbest_size", &nbest_size_tensor)); - const Tensor* alpha_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("alpha", &alpha_tensor)); - - OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp)); - - if (return_nbest_) { - OP_REQUIRES(ctx, nbest_size_tensor->dims() == 0, - errors::InvalidArgument( - "When return_nbest is true nbest_size must " - "be a scalar; got", - nbest_size_tensor->shape().DebugString(), "instead")); - OP_REQUIRES(ctx, nbest_size_tensor->scalar<int32>()() >= 1, - errors::InvalidArgument( - "When return_nbest is true nbest_size must be >= 1; got ", - nbest_size_tensor->scalar<int32>()())); - } - - std::vector<sentencepiece::SentencePieceText> results( - return_nbest_ ? 0 : num_of_input_values); - std::vector<sentencepiece::NBestSentencePieceText> nbest_results( - return_nbest_ ? 
num_of_input_values : 0); - if (num_of_input_values > 0) { - const bool return_nbest = return_nbest_; - const auto& worker_threads = - *(ctx->device()->tensorflow_cpu_worker_threads()); - ::tensorflow::Shard( - worker_threads.num_threads, // max parallelism - worker_threads.workers, // thread pool - num_of_input_values, // total number of data to process. - kCostPerUnit, - [ctx, sp, &input_values_flat, &results, &nbest_results, - &nbest_size_tensor, &alpha_tensor, - return_nbest](int64 start, int64 limit) { - absl::ReaderMutexLock lock(&sp->mu); - for (int i = start; i < limit; ++i) { - const int32 nbest_size = nbest_size_tensor->dims() == 1 - ? nbest_size_tensor->vec<int32>()(i) - : nbest_size_tensor->scalar<int32>()(); - if (return_nbest) { - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.NBestEncode( - input_values_flat(i), nbest_size, - &nbest_results[i]))); - } else if (nbest_size == 0 || nbest_size == 1) { - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode( - input_values_flat(i), &results[i]))); - } else { - const float alpha = alpha_tensor->dims() == 1 - ? 
alpha_tensor->vec<float>()(i) - : alpha_tensor->scalar<float>()(); - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.SampleEncode( - input_values_flat(i), nbest_size, alpha, - &results[i]))); - } - } - }); - } - - if (return_nbest_) { - for (auto& nbest : nbest_results) { - for (auto& result : nbest.nbests()) { - results.push_back(std::move(result)); - } - } - } - int64 total_tokens = 0; - for (auto& sp_result : results) { - total_tokens += sp_result.pieces_size(); - } - - Tensor* output_values_tensor = nullptr; - Tensor* output_splits_tensor = nullptr; - Tensor* output_starts_tensor = nullptr; - Tensor* output_limits_tensor = nullptr; - - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, {total_tokens}, &output_values_tensor)); - int64 splits_size = results.size() + 1; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(1, {splits_size}, &output_splits_tensor)); - OP_REQUIRES_OK( - ctx, ctx->allocate_output(2, {total_tokens}, &output_starts_tensor)); - OP_REQUIRES_OK( - ctx, ctx->allocate_output(3, {total_tokens}, &output_limits_tensor)); - - auto values_tensor_flat = output_values_tensor->vec<T>(); - auto splits_tensor_flat = output_splits_tensor->vec<Tsplits>(); - auto starts_tensor_flat = output_starts_tensor->vec<int64>(); - auto limits_tensor_flat = output_limits_tensor->vec<int64>(); - - int i = 0; - splits_tensor_flat(0) = 0; - for (int row = 0; row < results.size(); ++row) { - for (auto& sp : results[row].pieces()) { - values_tensor_flat(i) = GetPieceOrId<T>(sp); - starts_tensor_flat(i) = sp.begin(); - limits_tensor_flat(i) = sp.end(); - ++i; - } - splits_tensor_flat(row + 1) = i; - } - } - - bool return_nbest_{false}; -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeWithOffsetsOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type") - .TypeConstraint<int32>("Tsplits"), - SentencepieceTokenizeWithOffsetsOp<int32, int32>); -REGISTER_KERNEL_BUILDER( - Name("SentencepieceTokenizeWithOffsetsOp") - .Device(DEVICE_CPU) - 
.TypeConstraint<tensorflow::tstring>("out_type") - .TypeConstraint<int32>("Tsplits"), - SentencepieceTokenizeWithOffsetsOp<tensorflow::tstring, int32>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeWithOffsetsOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type") - .TypeConstraint<int64>("Tsplits"), - SentencepieceTokenizeWithOffsetsOp<int32, int64>); -REGISTER_KERNEL_BUILDER( - Name("SentencepieceTokenizeWithOffsetsOp") - .Device(DEVICE_CPU) - .TypeConstraint<tensorflow::tstring>("out_type") - .TypeConstraint<int64>("Tsplits"), - SentencepieceTokenizeWithOffsetsOp<tensorflow::tstring, int64>); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceTokenizeWithOffsetsOp"); - -template <typename T, typename Tsplits> -class SentencepieceDetokenizeOp : public OpKernel { - public: - explicit SentencepieceDetokenizeOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - const Tensor& input_values_tensor = ctx->input(1); - const auto input_values_flat = input_values_tensor.flat<T>(); - const Tensor& input_splits_tensor = ctx->input(2); - const auto input_splits_flat = input_splits_tensor.flat<Tsplits>(); - const int64 num_of_sentences = input_splits_flat.size() - 1; - - OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp)); - - Tensor* output_tensor; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, {num_of_sentences}, &output_tensor)); - auto output_flat = output_tensor->flat<tensorflow::tstring>(); - - if (input_values_flat.size() > 0) { - const auto& worker_threads = - *(ctx->device()->tensorflow_cpu_worker_threads()); - ::tensorflow::Shard( - worker_threads.num_threads, // max 
parallelism - worker_threads.workers, // thread pool - num_of_sentences, // total number of data to process. - kCostPerUnit, - [ctx, sp, &input_values_flat, &input_splits_flat, &output_flat]( - int64 start, int64 limit) { - absl::ReaderMutexLock lock(&sp->mu); - for (int i = start; i < limit; ++i) { - if (i + 1 >= input_splits_flat.size()) { - ctx->CtxFailure(errors::OutOfRange("Invalid splits; ", i)); - return; - } - if (input_splits_flat(i) > input_values_flat.size()) { - ctx->CtxFailure(errors::OutOfRange( - "Splits and values do not match; split ", - input_splits_flat(i), "but values size is ", - input_values_flat.size())); - return; - } - const std::vector<typename std::conditional< - std::is_same<T, tstring>::value, std::string, T>::type> - pieces(&input_values_flat(input_splits_flat(i)), - &input_values_flat(input_splits_flat(i + 1))); - std::string output_flat_str; - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Decode( - pieces, &output_flat_str))); - output_flat(i) = output_flat_str; - } - }); - } - } -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceDetokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("T") - .TypeConstraint<int32>("Tsplits"), - SentencepieceDetokenizeOp<int32, int32>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceDetokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<tensorflow::tstring>("T") - .TypeConstraint<int32>("Tsplits"), - SentencepieceDetokenizeOp<tensorflow::tstring, int32>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceDetokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("T") - .TypeConstraint<int64>("Tsplits"), - SentencepieceDetokenizeOp<int32, int64>); -REGISTER_KERNEL_BUILDER(Name("SentencepieceDetokenizeOp") - .Device(DEVICE_CPU) - .TypeConstraint<tensorflow::tstring>("T") - .TypeConstraint<int64>("Tsplits"), - SentencepieceDetokenizeOp<tensorflow::tstring, int64>); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceDetokenizeOp"); - -class SentencepieceVocabSizeOp : public OpKernel { - public: - 
explicit SentencepieceVocabSizeOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor)); - output_tensor->scalar<int32>()() = sp->processor.GetPieceSize(); - } -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceVocabSizeOp").Device(DEVICE_CPU), - SentencepieceVocabSizeOp); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceVocabSizeOp"); - -class SentencepieceIdToStringOp : public OpKernel { - public: - explicit SentencepieceIdToStringOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - const Tensor& input_tensor = ctx->input(1); - const auto input_tensor_flat = input_tensor.flat<int32>(); - Tensor* output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, input_tensor.shape(), &output_tensor)); - auto output_tensor_flat = output_tensor->flat<tensorflow::tstring>(); - - absl::ReaderMutexLock lock(&sp->mu); - for (int i = 0; i < input_tensor_flat.size(); ++i) { - output_tensor_flat(i) = sp->processor.IdToPiece(input_tensor_flat(i)); - } - } -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceIdToStringOp").Device(DEVICE_CPU), - SentencepieceIdToStringOp); 
-ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceIdToStringOp"); - -class SentencepieceStringToIdOp : public OpKernel { - public: - explicit SentencepieceStringToIdOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - SentencepieceResource* sp; - const Tensor& resource_tensor = ctx->input(0); - ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup<SentencepieceResource>( - resource_handle.container(), resource_handle.name(), &sp)); - core::ScopedUnref unref_me(sp); - - const Tensor& input_tensor = ctx->input(1); - const auto input_tensor_flat = input_tensor.flat<tensorflow::tstring>(); - Tensor* output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, input_tensor.shape(), &output_tensor)); - auto output_tensor_flat = output_tensor->flat<int32>(); - - absl::ReaderMutexLock lock(&sp->mu); - for (int i = 0; i < input_tensor_flat.size(); ++i) { - output_tensor_flat(i) = sp->processor.PieceToId(input_tensor_flat(i)); - } - } -}; - -REGISTER_KERNEL_BUILDER(Name("SentencepieceStringToIdOp").Device(DEVICE_CPU), - SentencepieceStringToIdOp); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("SentencepieceStringToIdOp"); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/spanning_tree_iterator.cc b/tensorflow_text/core/kernels/spanning_tree_iterator.cc deleted file mode 100644 index 1c859a543..000000000 --- a/tensorflow_text/core/kernels/spanning_tree_iterator.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/spanning_tree_iterator.h" - -namespace tensorflow { -namespace text { - -SpanningTreeIterator::SpanningTreeIterator(bool forest) : forest_(forest) {} - -bool SpanningTreeIterator::HasCycle(const SourceList &sources) { - // Flags for whether each node has already been searched. - searched_.assign(sources.size(), false); - - // Flags for whether the search is currently visiting each node. - visiting_.assign(sources.size(), false); - - // Search upwards from each node to find cycles. - for (uint32 initial_node = 0; initial_node < sources.size(); ++initial_node) { - // Search upwards to try to find a cycle. - uint32 current_node = initial_node; - while (true) { - if (searched_[current_node]) break; // already searched - if (visiting_[current_node]) return true; // revisiting implies cycle - visiting_[current_node] = true; // mark as being currently visited - const uint32 source_node = sources[current_node]; - if (source_node == current_node) break; // self-loops are roots - current_node = source_node; // advance upwards - } - - // No cycle; search upwards again to update flags. 
- current_node = initial_node; - while (true) { - if (searched_[current_node]) break; // already searched - searched_[current_node] = true; - visiting_[current_node] = false; - const uint32 source_node = sources[current_node]; - if (source_node == current_node) break; // self-loops are roots - current_node = source_node; // advance upwards - } - } - - return false; -} - -uint32 SpanningTreeIterator::NumRoots(const SourceList &sources) { - uint32 num_roots = 0; - for (uint32 node = 0; node < sources.size(); ++node) { - num_roots += (node == sources[node]); - } - return num_roots; -} - -bool SpanningTreeIterator::NextSourceList(SourceList *sources) { - const uint32 num_nodes = sources->size(); - for (uint32 i = 0; i < num_nodes; ++i) { - const uint32 new_source = ++(*sources)[i]; - if (new_source < num_nodes) return true; // absorbed in this digit - (*sources)[i] = 0; // overflowed this digit, carry to next digit - } - return false; // overflowed the last digit -} - -bool SpanningTreeIterator::NextTree(SourceList *sources) { - // Iterate source lists, skipping non-trees. - while (NextSourceList(sources)) { - // Check the number of roots. - const uint32 num_roots = NumRoots(*sources); - if (forest_) { - if (num_roots == 0) continue; - } else { - if (num_roots != 1) continue; - } - - // Check for cycles. - if (HasCycle(*sources)) continue; - - // Acyclic and rooted, therefore tree. - return true; - } - return false; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/spanning_tree_iterator.h b/tensorflow_text/core/kernels/spanning_tree_iterator.h index 68bc6f14a..34041d157 100644 --- a/tensorflow_text/core/kernels/spanning_tree_iterator.h +++ b/tensorflow_text/core/kernels/spanning_tree_iterator.h @@ -12,67 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ -#include <vector> +#include "tensorflow/core/kernels/text/spanning_tree_iterator.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -namespace text { - -// A class that iterates over all possible spanning trees of a complete digraph. -// Thread-compatible. Useful for brute-force comparison tests. -// -// TODO(terrykoo): Try using Prufer sequences, which are more efficient to -// enumerate as there are no non-trees to filter out. -class SpanningTreeIterator { - public: - // An array that provides the source of the inbound arc for each node. Roots - // are represented as self-loops. - using SourceList = std::vector<uint32>; - - // Creates a spanning tree iterator. If |forest| is true, then this iterates - // over forests instead of trees (i.e., multiple roots are allowed). - explicit SpanningTreeIterator(bool forest); - - // Applies the |functor| to all spanning trees (or forests, if |forest_| is - // true) of a complete digraph containing |num_nodes| nodes. Each tree is - // passed to the |functor| as a SourceList. - template <class Functor> - void ForEachTree(uint32 num_nodes, Functor functor) { - // Conveniently, the all-zero vector represents a valid tree. - SourceList sources(num_nodes, 0); - do { - functor(sources); - } while (NextTree(&sources)); - } - - private: - // Returns true if the |sources| contains a cycle. - bool HasCycle(const SourceList &sources); - - // Returns the number of roots in the |sources|. - static uint32 NumRoots(const SourceList &sources); - - // Advances |sources| to the next source list, or returns false if there are - // no more source lists. 
- static bool NextSourceList(SourceList *sources); - - // Advances |sources| to the next tree (or forest, if |forest_| is true), or - // returns false if there are no more trees. - bool NextTree(SourceList *sources); - - // If true, iterate over spanning forests instead of spanning trees. - const bool forest_; - - // Workspaces used by the search in HasCycle(). - std::vector<bool> searched_; - std::vector<bool> visiting_; -}; - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_SPANNING_TREE_ITERATOR_H_ diff --git a/tensorflow_text/core/kernels/spanning_tree_iterator_test.cc b/tensorflow_text/core/kernels/spanning_tree_iterator_test.cc deleted file mode 100644 index db3e4439e..000000000 --- a/tensorflow_text/core/kernels/spanning_tree_iterator_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/spanning_tree_iterator.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace text { - -// Testing rig. When the bool parameter is true, iterates over spanning forests -// instead of spanning trees. 
-class SpanningTreeIteratorTest : public ::testing::TestWithParam<bool> { - protected: - using SourceList = SpanningTreeIterator::SourceList; - - // Returns |base|^|exponent|. Computes the value as an integer to avoid - // rounding issues. - static int Pow(int base, int exponent) { - double real_product = 1.0; - int product = 1; - for (int i = 0; i < exponent; ++i) { - product *= base; - real_product *= base; - } - CHECK_EQ(product, real_product) << "Overflow detected."; - return product; - } - - // Expects that the number of possible spanning trees for a complete digraph - // of |num_nodes| nodes is |expected_num_trees|. - void ExpectNumTrees(int num_nodes, int expected_num_trees) { - int actual_num_trees = 0; - iterator_.ForEachTree( - num_nodes, [&](const SourceList &sources) { ++actual_num_trees; }); - LOG(INFO) << "num_nodes=" << num_nodes - << " expected_num_trees=" << expected_num_trees - << " actual_num_trees=" << actual_num_trees; - EXPECT_EQ(expected_num_trees, actual_num_trees); - } - - // Expects that the set of possible spanning trees for a complete digraph of - // |num_nodes| nodes is |expected_trees|. - void ExpectTrees(int num_nodes, const std::set<SourceList> &expected_trees) { - std::set<SourceList> actual_trees; - iterator_.ForEachTree(num_nodes, [&](const SourceList &sources) { - CHECK(actual_trees.insert(sources).second); - }); - EXPECT_EQ(expected_trees, actual_trees); - } - - // Instance for tests. Shared across assertions in a test to exercise reuse. 
- SpanningTreeIterator iterator_{GetParam()}; -}; - -INSTANTIATE_TEST_SUITE_P(AllowForest, SpanningTreeIteratorTest, - ::testing::Bool()); - -TEST_P(SpanningTreeIteratorTest, NumberOfTrees) { - // According to Cayley's formula, the number of undirected spanning trees on a - // complete graph of n nodes is n^{n-2}: - // https://en.wikipedia.org/wiki/Cayley%27s_formula - // - // To count the number of directed spanning trees, note that each undirected - // spanning tree gives rise to n directed spanning trees: choose one of the n - // nodes as the root, and then orient arcs outwards. Therefore, the number of - // directed spanning trees on a complete digraph of n nodes is n^{n-1}. - // - // To count the number of directed spanning forests, consider undirected - // spanning trees on a complete graph of n+1 nodes. Arbitrarily select one - // node as the artificial root, orient arcs outwards, and then delete the - // artificial root and its outbound arcs. The result is a directed spanning - // forest on n nodes. Therefore, the number of directed spanning forests on a - // complete digraph of n nodes is (n+1)^{n-1}. 
- for (int num_nodes = 1; num_nodes <= 7; ++num_nodes) { - if (GetParam()) { // forest - ExpectNumTrees(num_nodes, Pow(num_nodes + 1, num_nodes - 1)); - } else { // tree - ExpectNumTrees(num_nodes, Pow(num_nodes, num_nodes - 1)); - } - } -} - -TEST_P(SpanningTreeIteratorTest, OneNodeDigraph) { ExpectTrees(1, {{0}}); } - -TEST_P(SpanningTreeIteratorTest, TwoNodeDigraph) { - if (GetParam()) { // forest - ExpectTrees(2, {{0, 0}, {0, 1}, {1, 1}}); // {0, 1} is two-root structure - } else { // tree - ExpectTrees(2, {{0, 0}, {1, 1}}); - } -} - -TEST_P(SpanningTreeIteratorTest, ThreeNodeDigraph) { - if (GetParam()) { // forest - ExpectTrees(3, {{0, 0, 0}, - {0, 0, 1}, - {0, 0, 2}, // 2-root - {0, 1, 0}, // 2-root - {0, 1, 1}, // 2-root - {0, 1, 2}, // 3-root - {0, 2, 0}, - {0, 2, 2}, // 2-root - {1, 1, 0}, - {1, 1, 1}, - {1, 1, 2}, // 2-root - {1, 2, 2}, - {2, 0, 2}, - {2, 1, 1}, - {2, 1, 2}, // 2-root - {2, 2, 2}}); - } else { // tree - ExpectTrees(3, {{0, 0, 0}, - {0, 0, 1}, - {0, 2, 0}, - {1, 1, 0}, - {1, 1, 1}, - {1, 2, 2}, - {2, 0, 2}, - {2, 1, 1}, - {2, 2, 2}}); - } -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc b/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc deleted file mode 100644 index 5491fab4d..000000000 --- a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include <limits> -#include <memory> -#include <string> -#include <vector> - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -namespace text { - -namespace { - -// Returns the length (number of bytes) of the UTF8 code point starting at src, -// by reading only the byte from address src. -// -// The result is a number from the set {1, 2, 3, 4}. -int OneCharLen(const char* src) { - // On most platforms, char is unsigned by default, but iOS is an exception. - // The cast below makes sure we always interpret *src as an unsigned char. 
- return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4" - [(*(reinterpret_cast<const unsigned char*>(src)) & 0xFF) >> 4]; -} - -bool GetUTF8Chars(absl::string_view text, - std::vector<absl::string_view>* chars) { - const char* start = text.data(); - const char* end = text.data() + text.size(); - while (start < end) { - const int char_length = OneCharLen(start); - if (char_length <= 0) { - return false; - } - chars->emplace_back(start, char_length); - start += char_length; - } - return true; -} - -bool IsBreakChar(absl::string_view text) { - UChar32 c; - int position = 0; - U8_NEXT_OR_FFFD(text.data(), position, text.length(), c); - return u_isUWhiteSpace(c); -} - -Status TokenizeByLabel(const absl::string_view& text, - const Tensor& labels_tensor, - bool force_split_at_break_character, - std::vector<std::string>* tokens, - std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_tokens) { - std::vector<absl::string_view> chars; - if (!GetUTF8Chars(text, &chars)) { - return Status(static_cast<::absl::StatusCode>( - absl::StatusCode::kInvalidArgument), - absl::StrCat("Input string is not utf8 valid: ", text)); - } - - if (chars.size() > labels_tensor.dim_size(0)) { - return Status(static_cast<::absl::StatusCode>( - absl::StatusCode::kInvalidArgument), - absl::StrCat("Number of labels ", labels_tensor.dim_size(0), - " is insufficient for text ", text)); - } - - const int split_label = 0; - bool last_character_is_break_character = false; - int start = 0; - bool has_new_token_generated_for_text = false; - const auto& labels = labels_tensor.unaligned_flat<int32>(); - for (int i = 0; i < chars.size(); ++i) { - const bool is_break_character = IsBreakChar(chars[i]); - if (!is_break_character) { - if (labels(i) == split_label || !has_new_token_generated_for_text || - (last_character_is_break_character && - force_split_at_break_character)) { - tokens->emplace_back(chars[i].data(), chars[i].length()); - begin_offset->push_back(start); - end_offset->push_back(start + 
chars[i].length()); - *num_tokens += 1; - has_new_token_generated_for_text = true; - } else { - tokens->back().append(chars[i].data(), chars[i].length()); - end_offset->back() = start + chars[i].length(); - } - } - - start += chars[i].length(); - last_character_is_break_character = is_break_character; - } - - return absl::OkStatus(); -} - -} // namespace - -class SplitMergeTokenizeWithOffsetsOp : public OpKernel { - public: - explicit SplitMergeTokenizeWithOffsetsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("force_split_at_break_character", - &force_split_at_break_character_)); - } - - void Compute(OpKernelContext* ctx) override { - const Tensor* input_values; - OP_REQUIRES_OK(ctx, ctx->input("input_values", &input_values)); - - const Tensor* labels; - OP_REQUIRES_OK(ctx, ctx->input("labels", &labels)); - const Tensor* row_splits; - OP_REQUIRES_OK(ctx, ctx->input("row_splits", &row_splits)); - OP_REQUIRES(ctx, input_values->dim_size(0) == row_splits->dim_size(0) - 1, - errors::InvalidArgument("Expecting row_splits have ", - input_values->dim_size(0) + 1, - " elements, got ", - row_splits->dim_size(0))); - - std::vector<string> tokens; - std::vector<int> begin_offset; - std::vector<int> end_offset; - std::vector<int> output_row_splits(1, 0); - - // Iterate through all the values and tokenize them. - const auto& values_vec = input_values->flat<tstring>(); - const auto& row_splits_vec = row_splits->flat<int32>(); - for (int i = 0; i < values_vec.size(); ++i) { - // Tokenize into tokens and record the offset locations. - int num_tokens = 0; - OP_REQUIRES_OK( - ctx, TokenizeByLabel( - values_vec(i), - labels->Slice(row_splits_vec(i), row_splits_vec(i + 1)), - force_split_at_break_character_, &tokens, &begin_offset, - &end_offset, &num_tokens)); - - // Record the row splits. 
- output_row_splits.push_back(num_tokens + output_row_splits.back()); - } - - std::vector<int64> output_tokens_shape; - output_tokens_shape.push_back(tokens.size()); - - std::vector<int64> output_row_splits_shape; - output_row_splits_shape.push_back(output_row_splits.size()); - - Tensor* output_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("output_values", - TensorShape(output_tokens_shape), - &output_values)); - auto output_values_vec = output_values->vec<tstring>(); - - Tensor* output_row_splits_tensor; - OP_REQUIRES_OK(ctx, - ctx->allocate_output("output_row_splits", - TensorShape(output_row_splits_shape), - &output_row_splits_tensor)); - auto output_row_splits_vec = output_row_splits_tensor->vec<int64>(); - - Tensor* start_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values", - TensorShape(output_tokens_shape), - &start_values)); - auto start_values_vec = start_values->vec<int64>(); - - Tensor* limit_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values", - TensorShape(output_tokens_shape), - &limit_values)); - auto limit_values_vec = limit_values->vec<int64>(); - - for (int i = 0; i < tokens.size(); ++i) { - output_values_vec(i) = tokens[i]; - } - - for (int i = 0; i < output_row_splits.size(); ++i) { - output_row_splits_vec(i) = output_row_splits[i]; - } - - for (int i = 0; i < begin_offset.size(); ++i) { - start_values_vec(i) = begin_offset[i]; - } - - for (int i = 0; i < end_offset.size(); ++i) { - limit_values_vec(i) = end_offset[i]; - } - } - - private: - bool force_split_at_break_character_; - - TF_DISALLOW_COPY_AND_ASSIGN(SplitMergeTokenizeWithOffsetsOp); -}; - -REGISTER_KERNEL_BUILDER( - Name("SplitMergeTokenizeWithOffsets").Device(DEVICE_CPU), - SplitMergeTokenizeWithOffsetsOp); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/string_vocab.cc b/tensorflow_text/core/kernels/string_vocab.cc deleted file mode 100644 index a2c239a93..000000000 --- 
a/tensorflow_text/core/kernels/string_vocab.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/string_vocab.h" - -namespace tensorflow { -namespace text { - -StringVocab::StringVocab(const std::vector<std::string>& vocab) - : vocab_(vocab) { - index_map_.reserve(vocab.size()); - for (int i = 0; i < vocab.size(); ++i) { - index_map_[vocab_[i]] = i; - } -} - -LookupStatus StringVocab::Contains(absl::string_view key, bool* value) const { - *value = index_map_.contains(key); - return LookupStatus(); -} - -absl::optional<int> StringVocab::LookupId(absl::string_view key) const { - auto it = index_map_.find(key); - if (it == index_map_.end()) { - return absl::nullopt; - } else { - return it->second; - } -} - -// Returns the key of `vocab_id` or empty if `vocab_id` is not valid. 
-absl::optional<absl::string_view> StringVocab::LookupWord(int vocab_id) const { - if (vocab_id >= vocab_.size() || vocab_id < 0) { - return absl::nullopt; - } - return vocab_[vocab_id]; -} -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/string_vocab.h b/tensorflow_text/core/kernels/string_vocab.h index 4590f2775..d58daa772 100644 --- a/tensorflow_text/core/kernels/string_vocab.h +++ b/tensorflow_text/core/kernels/string_vocab.h @@ -15,34 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_STRING_VOCAB_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_STRING_VOCAB_H_ -#include <string> -#include <vector> - -#include "absl/container/flat_hash_map.h" -#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" - -namespace tensorflow { -namespace text { - -// An implementation of WordpieceVocab, used (1) to store the input vocabulary -// and (2) to call the original implementation of WordPiece tokenization to -// pre-compute the result for the suffix indicator string. -class StringVocab : public WordpieceVocab { - public: - explicit StringVocab(const std::vector<std::string>& vocab); - StringVocab(const StringVocab&) = delete; - StringVocab& operator=(const StringVocab&) = delete; - LookupStatus Contains(absl::string_view key, bool* value) const override; - absl::optional<int> LookupId(absl::string_view key) const; - // Returns the key of `vocab_id` or empty if `vocab_id` is not valid. 
- absl::optional<absl::string_view> LookupWord(int vocab_id) const; - int Size() const { return index_map_.size(); } - - private: - std::vector<std::string> vocab_; - absl::flat_hash_map<absl::string_view, int> index_map_; -}; -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/string_vocab.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_STRING_VOCAB_H_ diff --git a/tensorflow_text/core/kernels/text_kernels_test_util.cc b/tensorflow_text/core/kernels/text_kernels_test_util.cc deleted file mode 100644 index 15da35665..000000000 --- a/tensorflow_text/core/kernels/text_kernels_test_util.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -using ::testing::MakeMatcher; -using ::testing::Matcher; -using ::testing::MatchResultListener; - -namespace tensorflow { -namespace text_kernels_test_util { - -bool TensorEqMatcher::MatchAndExplain( - Tensor actual, ::testing::MatchResultListener* listener) const { - std::string expect_values = expect_.SummarizeValue(expect_.NumElements()); - std::string actual_values = actual.SummarizeValue(actual.NumElements()); - if (expect_.dtype() != actual.dtype() || expect_.shape() != actual.shape() || - expect_values != actual_values) { - *listener << "\n dtype=" << DataTypeString(actual.dtype()); - *listener << "\n shape=" << actual.shape().DebugString(); - *listener << "\n values=" << actual_values; - return false; - } - return true; -} - -void TensorEqMatcher::DescribeTo(::std::ostream* gmock_os) const { - *gmock_os << "dtype=" << DataTypeString(expect_.dtype()) - << "\n shape=" << expect_.shape().DebugString() - << "\n values=" - << expect_.SummarizeValue(expect_.NumElements()); -} - -void TensorEqMatcher::DescribeNegationTo(::std::ostream* gmock_os) const { - *gmock_os << "is not equal to " << expect_.DebugString(); -} - -bool TensorHasShapeMatcher::MatchAndExplain( - Tensor actual, ::testing::MatchResultListener* listener) const { - if (expect_ != actual.shape()) { - *listener << "\n shape=" << actual.shape().DebugString(); - return false; - } - return true; -} - -void TensorHasShapeMatcher::DescribeTo(::std::ostream* gmock_os) const { - *gmock_os << "shape=" << expect_.DebugString(); -} - -void TensorHasShapeMatcher::DescribeNegationTo(::std::ostream* gmock_os) const { - *gmock_os << "shape!=" << expect_.DebugString(); -} - -Matcher<Tensor> TensorHasShape(const TensorShape& shape) { - // MakeMatcher takes ownership of the TensorHasShapeMatcher. 
- return MakeMatcher(new TensorHasShapeMatcher(shape)); -} - -} // namespace text_kernels_test_util -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/text_kernels_test_util.h b/tensorflow_text/core/kernels/text_kernels_test_util.h index 9762b385f..c992dc278 100644 --- a/tensorflow_text/core/kernels/text_kernels_test_util.h +++ b/tensorflow_text/core/kernels/text_kernels_test_util.h @@ -12,112 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// GMock matchers for testing text kernels: -// TensorHasShapeAndValues<DTYPE>({dim1, ..., dimN}, {v1, v2, ..., vN}); -// VectorEq<DTYPE>({v1, v2, ..., vN}); -// MatrixEq<DTYPE>({{v1_1, ..., v1_M}, ..., {vN_1, ..., vN_M}}); -// TensorHasShape({dim1, ..., dimN}); +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ +#include "tensorflow/core/kernels/text/text_kernels_test_util.h" -#include <gmock/gmock.h> -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_testutil.h" - -namespace tensorflow { -namespace text_kernels_test_util { - -// GMock MatcherInterface for testing tensor equality. -class TensorEqMatcher : public ::testing::MatcherInterface<Tensor> { - public: - explicit TensorEqMatcher(const Tensor& expect) : expect_(expect) {} - bool MatchAndExplain(Tensor actual, - ::testing::MatchResultListener* listener) const override; - void DescribeTo(::std::ostream* gmock_os) const override; - void DescribeNegationTo(::std::ostream* gmock_os) const override; - - private: - Tensor expect_; -}; - -// GMock MatcherInterface for testing tensor shapes. 
-class TensorHasShapeMatcher : public ::testing::MatcherInterface<Tensor> { - public: - explicit TensorHasShapeMatcher(const TensorShape& expect) : expect_(expect) {} - bool MatchAndExplain(Tensor actual, - ::testing::MatchResultListener* listener) const override; - void DescribeTo(::std::ostream* gmock_os) const override; - void DescribeNegationTo(::std::ostream* gmock_os) const override; - - private: - TensorShape expect_; -}; - -// Returns a gmock matcher that checks whether a given tensor has the specified -// dtype, values, and shape. dtype is specified using the template parameter. -// values are specified as a flattened vector. -// Example: -// EXPECT_THAT(*GetOutput(0), -// TensorHasShapeAndValues<int64>({3, 2}, {1, 2, 3, 4, 5, 6}); -template <typename DTYPE> -::testing::Matcher<Tensor> TensorHasShapeAndValues( - const TensorShape& shape, const std::vector<DTYPE>& values) { - Tensor expect = test::AsTensor<DTYPE>(values, shape); - // MakeMatcher takes ownership of the TensorEqMatcher. - return ::testing::MakeMatcher(new TensorEqMatcher(expect)); -} - -// Returns a gmock matcher that checks whether a given tensor is a 1-D tensor -// with the specified dtype and values. dtype is specified using the template -// parameter. -// Example: -// EXPECT_THAT(*GetOutput(0), -// VectorEq<int64>({1, 2, 3, 4, 5, 6}); -template <typename DTYPE> -::testing::Matcher<Tensor> VectorEq(const std::vector<DTYPE>& values) { - int64_t nvals = values.size(); - Tensor expect = test::AsTensor<DTYPE>(values, {nvals}); - // MakeMatcher takes ownership of the TensorEqMatcher. - return ::testing::MakeMatcher(new TensorEqMatcher(expect)); -} - -// Returns a gmock matcher that checks whether a given tensor is a 2-D tensor -// with the specified dtype and values. dtype is specified using the template -// parameter. values are specified as a nested vector. All rows of the values -// vector must have the same length. 
The values vector may not be empty, -// since we can't infer the number of columns for an empty matrix; to test -// empty matrices, use the more general TensorHasShapeAndValues() instead. -// Example: -// EXPECT_THAT(*GetOutput(0), -// MatrixEq<int64>({{1, 2, 3}, {4, 5, 6}}); -template <typename DTYPE> -::testing::Matcher<Tensor> MatrixEq( - const std::vector<std::vector<DTYPE>>& values) { - int64_t nrows = values.size(); - CHECK_GT(nrows, 0) // Crash OK - << "Invalid use of MatrixEq: to test empty matrices, use " - << "TensorHasShapeAndValues<dtype>{{0, ndims}, {}} instead."; - int64_t ncols = values[0].size(); - std::vector<DTYPE> flat; - for (const auto& row : values) { - CHECK_EQ(ncols, row.size()) // Crash OK - << "Invalid use of MatrixEq: all rows must have equal length"; - flat.insert(flat.end(), row.begin(), row.end()); - } - Tensor expect = test::AsTensor<DTYPE>(flat, TensorShape({nrows, ncols})); - // MakeMatcher takes ownership of the TensorEqMatcher. - return ::testing::MakeMatcher(new TensorEqMatcher(expect)); -} - -// Returns a gmock matcher that checks whether a given tensor has a specified -// shape. -// Example: -// EXPECT_THAT(*GetOutput(0), TensorHasShape({2, 8}); -::testing::Matcher<Tensor> TensorHasShape(const TensorShape& shape); - -} // namespace text_kernels_test_util -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TEXT_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc b/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc deleted file mode 100644 index 39262dd94..000000000 --- a/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <limits> -#include <memory> -#include <string> -#include <vector> - -#include "absl/strings/str_cat.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -namespace text { - -namespace { - -// Returns the length (number of bytes) of the UTF8 code point starting at src, -// by reading only the byte from address src. -// -// The result is a number from the set {1, 2, 3, 4}. -int OneCharLen(const char* src) { - // On most platforms, char is unsigned by default, but iOS is an exception. - // The cast below makes sure we always interpret *src as an unsigned char. 
- return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4" - [(*(reinterpret_cast<const unsigned char*>(src)) & 0xFF) >> 4]; -} - -bool GetUTF8Chars(absl::string_view text, - std::vector<absl::string_view>* chars) { - const char* start = text.data(); - const char* end = text.data() + text.size(); - while (start < end) { - const int char_length = OneCharLen(start); - if (char_length <= 0) { - return false; - } - chars->emplace_back(start, char_length); - start += char_length; - } - return true; -} - -bool IsBreakChar(absl::string_view text) { - UChar32 c; - int position = 0; - U8_NEXT_OR_FFFD(text.data(), position, text.length(), c); - return u_isUWhiteSpace(c); -} - -// Tokenizes text, the input string #(batch_index). Knowing the batch_index -// allows us to retrieve the corresponding data from logits. I.e., the logits -// for the i-th character from text are logits(batch_index, i, 0) (for the -// "split" action) and logits(batch_index, i, 1) (for the "merge" action). -Status TokenizeByLogits(const absl::string_view& text, - const TTypes<const float, 3>::Tensor& logits, - int batch_index, - bool force_split_at_break_character, - std::vector<std::string>* tokens, - std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_tokens) { - std::vector<absl::string_view> chars; - if (!GetUTF8Chars(text, &chars)) { - return Status( - static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument), - absl::StrCat("Input string is not utf8 valid: ", text)); - } - - if (chars.size() > logits.dimension(1)) { - return Status( - static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument), - absl::StrCat("Number of logits, ", logits.dimension(1), - ", is insufficient for text \"", text, "\"")); - } - - bool last_character_is_break_character = false; - int start = 0; - bool has_new_token_generated_for_text = false; - for (int i = 0; i < chars.size(); ++i) { - const bool is_break_character = IsBreakChar(chars[i]); - if (!is_break_character) { - const float logit_split = 
logits(batch_index, i, 0); - const float logit_merge = logits(batch_index, i, 1); - if ((logit_split > logit_merge) || - !has_new_token_generated_for_text || - (last_character_is_break_character && - force_split_at_break_character)) { - tokens->emplace_back(chars[i].data(), chars[i].length()); - begin_offset->push_back(start); - end_offset->push_back(start + chars[i].length()); - *num_tokens += 1; - has_new_token_generated_for_text = true; - } else { - tokens->back().append(chars[i].data(), chars[i].length()); - end_offset->back() = start + chars[i].length(); - } - } - - start += chars[i].length(); - last_character_is_break_character = is_break_character; - } - - return absl::OkStatus(); -} - -} // namespace - -class TokenizerFromLogitsOp : public OpKernel { - public: - explicit TokenizerFromLogitsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor* strings; - OP_REQUIRES_OK(ctx, ctx->input("strings", &strings)); - const Tensor* logits; - OP_REQUIRES_OK(ctx, ctx->input("logits", &logits)); - OP_REQUIRES(ctx, strings->dim_size(0) == logits->dim_size(0), - errors::InvalidArgument("Expecting logits to have ", - strings->dim_size(0), - " rows, got ", - logits->dim_size(0))); - const Tensor* force_split_at_break_character; - OP_REQUIRES_OK(ctx, ctx->input("force_split_at_break_character", - &force_split_at_break_character)); - const bool force_split_at_break_character_bool = - force_split_at_break_character->scalar<bool>()(); - - std::vector<string> tokens; - std::vector<int> begin_offset; - std::vector<int> end_offset; - std::vector<int> output_row_splits(1, 0); - - // Tensor to access values from logits. - const TTypes<const float, 3>::Tensor logits_tensor = - logits->tensor<float, 3>(); - - // Iterate through all the values and tokenize them. 
- const auto& strings_vec = strings->flat<tstring>(); - OP_REQUIRES(ctx, logits_tensor.dimension(0) >= strings_vec.size(), - errors::Internal("Bad logits dimension #0: ", - logits_tensor.dimension(0), " < ", - strings_vec.size())); - // Dimension #1 of logits will be checked inside TokenizeByLogits. - OP_REQUIRES(ctx, logits_tensor.dimension(2) == 2, - errors::Internal("Bad logits dimension #2: ", - logits_tensor.dimension(2), " != 2")); - for (int i = 0; i < strings_vec.size(); ++i) { - // Tokenize into tokens and record the offset locations. - int num_tokens = 0; - OP_REQUIRES_OK( - ctx, TokenizeByLogits( - strings_vec(i), - logits_tensor, i, - force_split_at_break_character_bool, - &tokens, &begin_offset, &end_offset, &num_tokens)); - - // Record the row splits. - output_row_splits.push_back(num_tokens + output_row_splits.back()); - } - - std::vector<int64> output_tokens_shape; - output_tokens_shape.push_back(tokens.size()); - - std::vector<int64> output_row_splits_shape; - output_row_splits_shape.push_back(output_row_splits.size()); - - Tensor* output_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("output_values", - TensorShape(output_tokens_shape), - &output_values)); - auto output_values_vec = output_values->vec<tstring>(); - - Tensor* output_row_splits_tensor; - OP_REQUIRES_OK(ctx, - ctx->allocate_output("row_splits", - TensorShape(output_row_splits_shape), - &output_row_splits_tensor)); - auto output_row_splits_vec = output_row_splits_tensor->vec<int64>(); - - Tensor* start_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values", - TensorShape(output_tokens_shape), - &start_values)); - auto start_values_vec = start_values->vec<int64>(); - - Tensor* limit_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values", - TensorShape(output_tokens_shape), - &limit_values)); - auto limit_values_vec = limit_values->vec<int64>(); - - for (int i = 0; i < tokens.size(); ++i) { - output_values_vec(i) = tokens[i]; - } - - for (int i = 0; i < 
output_row_splits.size(); ++i) { - output_row_splits_vec(i) = output_row_splits[i]; - } - - for (int i = 0; i < begin_offset.size(); ++i) { - start_values_vec(i) = begin_offset[i]; - } - - for (int i = 0; i < end_offset.size(); ++i) { - limit_values_vec(i) = end_offset[i]; - } - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(TokenizerFromLogitsOp); -}; - -REGISTER_KERNEL_BUILDER( - Name("TokenizerFromLogits").Device(DEVICE_CPU), - TokenizerFromLogitsOp); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/trimmer.h b/tensorflow_text/core/kernels/trimmer.h index f2781fc93..7c4f463cb 100644 --- a/tensorflow_text/core/kernels/trimmer.h +++ b/tensorflow_text/core/kernels/trimmer.h @@ -15,78 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TRIMMER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TRIMMER_H_ -#include <vector> - -#include "absl/types/span.h" - -namespace tensorflow { -namespace text { - -using Mask = std::vector<bool>; -template <typename T> -using Values = std::vector<T>; -template <typename T> -using ValuesSpan = absl::Span<T>; -template <typename Tsplits> -using RowSplits = std::vector<Tsplits>; -template <typename Tsplits> -using RowSplitsSpan = absl::Span<Tsplits>; - -template <typename T> -class Trimmer { - using ValuesT = Values<T>; - - public: - // Generates masks for a single batch of values. - virtual std::vector<Mask> GenerateMasks( - const std::vector<ValuesT>& values) const = 0; - - // Trims a single batch of values. - virtual void Trim(std::vector<ValuesT>* values) const = 0; - - virtual ~Trimmer() = default; -}; - -template <typename T, typename Tsplits> -class BatchTrimmer { - using Values_ = Values<T>; - using ValuesSpan_ = ValuesSpan<T>; - using RowSplits_ = RowSplits<Tsplits>; - using RowSplitsSpan_ = RowSplitsSpan<Tsplits>; - - public: - // Generates masks for a batch of value row splits. 
- // - // Args: - // row_splits: Row splits of the values in the shape [batch, (num values)] - // - // Returns: - // The returned value is a flattened list of mask values which can be split - // into batches using the same input row splits. - virtual std::vector<Mask> GenerateMasksBatch( - const std::vector<RowSplits_>& row_splits) const = 0; - virtual std::vector<Mask> GenerateMasksBatch( - const std::vector<RowSplitsSpan_>& row_splits) const = 0; - - // Trims a batch of values given their flattened values and row splits. - // - // Args: - // flat_values: Flattened values in shape [batch, (num values)] - // row_splits: Row splits of the values in the shape [batch, (num values)] - // - // Returns: - // The returned values are the flattened trimmed values and new row splits. - virtual std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch( - const std::vector<Values_>& flat_values, - const std::vector<RowSplits_>& row_splits) const = 0; - virtual std::pair<std::vector<Values_>, std::vector<RowSplits_>> TrimBatch( - const std::vector<ValuesSpan_>& flat_values, - const std::vector<RowSplitsSpan_>& row_splits) const = 0; - - virtual ~BatchTrimmer() = default; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/trimmer.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_TRIMMER_H_ diff --git a/tensorflow_text/core/kernels/unicode_script_tokenize_kernel.cc b/tensorflow_text/core/kernels/unicode_script_tokenize_kernel.cc deleted file mode 100644 index 6217f1bd5..000000000 --- a/tensorflow_text/core/kernels/unicode_script_tokenize_kernel.cc +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <string.h> - -#include <vector> - -#include "icu4c/source/common/unicode/errorcode.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/uscript.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/lookup_interface.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace text { - -template <typename SPLITS_TYPE> -class UnicodeScriptTokenizeWithOffsetsOp : public OpKernel { - public: - explicit UnicodeScriptTokenizeWithOffsetsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_whitespace", &keep_whitespace_)); - } - - /** - * Breaks a series of codepoints into individual groups based on the script - * code as defined by ICU. - * - * We gain a dimension while tokenizing since a series of integer codepoints - * is tokenized into different codepoint groups. 
- * - * This accepts two input tensors: a rank 1 tensor of codepoint values and - * a single rank 1 tensor of splits which determine where each string begins - * and ends from the provided codepoints. - */ - void Compute(OpKernelContext* context) override { - // Get inputs - const Tensor& input_values_tensor = context->input(0); - const auto input_values_flat = input_values_tensor.flat<int32>(); - const Tensor& input_splits_tensor = context->input(1); - const auto input_splits_flat = input_splits_tensor.flat<SPLITS_TYPE>(); - - // Since we limit to a 2-D input (flat_values of rank 1 and a single splits - // tensor), our output dimension will always be 3-D (flat_values of rank 1 - // with two splits - inner for the tokenized values and the outer for those - // grouped by the original strings). - // A few things to note: - // 1) The values and inner splits of the tokenized strings have an unknown - // length, as well as the offsets, so we allocate them at the end. - // 2) The outer splits of the tokenized strings matches that of the offset - // splits. Thus, we will only return one set and use it for all of them. - // 3) The outer splits shape will match the original input_splits. - Tensor* output_outer_splits_tensor; - OP_REQUIRES_OK(context, - context->allocate_output("output_outer_splits", - input_splits_tensor.shape(), - &output_outer_splits_tensor)); - auto output_outer_splits_flat = - output_outer_splits_tensor->flat<SPLITS_TYPE>(); - - std::vector<int32> output_values; - std::vector<SPLITS_TYPE> output_values_inner_splits; - std::vector<int64> output_offset_starts; - std::vector<int64> output_offset_limits; - - // Loop over the codepoints (a split at a time) and create splits of tokens. 
- icu::ErrorCode status; - for (int splits_idx = 0; splits_idx < input_splits_flat.size() - 1; - splits_idx++) { - output_outer_splits_flat(splits_idx) = output_offset_starts.size(); - UScriptCode prev_script = USCRIPT_INVALID_CODE; - bool token_has_start_set = false; - int32 curr_skipped_spaces = 0; // Used when computing the end of a token - const int curr_word_start_idx = input_splits_flat(splits_idx); - bool was_space = false; - for (int values_idx = curr_word_start_idx; - values_idx < input_splits_flat(splits_idx + 1); values_idx++) { - const int32 input_value = input_values_flat(values_idx); - const bool is_space = u_isUWhiteSpace(input_value); - UScriptCode script = uscript_getScript(input_value, status); - // Split these failures out as if they are a different code and ignore - // the error. - if (status.isFailure()) { - status.reset(); - script = USCRIPT_INVALID_CODE; - } - // Split out a new token if the unicode script changes from the - // previous token. - if (script != prev_script || - (keep_whitespace_ && is_space != was_space)) { - if (token_has_start_set) { - output_offset_limits.push_back(values_idx - curr_word_start_idx - - curr_skipped_spaces); - } - prev_script = script; - token_has_start_set = false; - } - // Only copy characters other than whitespace. Because of this, also do - // not start new tokens until a character other than a space is reached. - if (!is_space || keep_whitespace_) { - if (!token_has_start_set) { - // Set token start offset relative to current string. - output_offset_starts.push_back(values_idx - curr_word_start_idx); - // Set split to indicate start of a new token. - output_values_inner_splits.push_back(output_values.size()); - token_has_start_set = true; - } - output_values.push_back(input_value); - } - if (!keep_whitespace_) { - if (is_space) { - curr_skipped_spaces++; - } else { - curr_skipped_spaces = 0; - } - } - was_space = is_space; - } - // Looping through the codepoints for current tokens complete. 
Now set the - // last limit of out last token (if we found a start earlier). - if (token_has_start_set) { - output_offset_limits.push_back(input_splits_flat(splits_idx + 1) - - curr_word_start_idx - - curr_skipped_spaces); - } - } - // Now set the closing value of our splits. - output_outer_splits_flat(input_splits_flat.size() - 1) = - output_offset_starts.size(); - output_values_inner_splits.push_back(output_values.size()); - -// Allocate output & fill output tensors. -#define DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(name, dtype) \ - int64 name##_size = name.size(); \ - Tensor* name##_tensor = nullptr; \ - OP_REQUIRES_OK(context, \ - context->allocate_output(#name, TensorShape({name##_size}), \ - &name##_tensor)); \ - auto name##_data = name##_tensor->flat<dtype>().data(); \ - /* For empty outputs, the data pointer might be null. */ \ - if (name##_size > 0) { \ - memcpy(name##_data, name.data(), name##_size * sizeof(dtype)); \ - } \ - do { \ - } while (false) - - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_values, int32); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_values_inner_splits, - SPLITS_TYPE); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_offset_starts, int64); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_offset_limits, int64); - -#undef DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR - } - - private: - bool keep_whitespace_; - - TF_DISALLOW_COPY_AND_ASSIGN(UnicodeScriptTokenizeWithOffsetsOp); -}; - -REGISTER_KERNEL_BUILDER(Name("UnicodeScriptTokenizeWithOffsets") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("Tsplits"), - UnicodeScriptTokenizeWithOffsetsOp<int32>); -REGISTER_KERNEL_BUILDER(Name("UnicodeScriptTokenizeWithOffsets") - .Device(DEVICE_CPU) - .TypeConstraint<int64>("Tsplits"), - UnicodeScriptTokenizeWithOffsetsOp<int64>); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/unicode_script_tokenize_kernel_test.cc b/tensorflow_text/core/kernels/unicode_script_tokenize_kernel_test.cc deleted file mode 
100644 index ebd84b10f..000000000 --- a/tensorflow_text/core/kernels/unicode_script_tokenize_kernel_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { -namespace text { - -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::VectorEq; - -class UnicodeScriptTokenizeWithOffsetsKernelTest - : public tensorflow::OpsTestBase { - public: - void MakeOp() { - TF_ASSERT_OK(NodeDefBuilder("tested_op", "UnicodeScriptTokenizeWithOffsets") - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -TEST_F(UnicodeScriptTokenizeWithOffsetsKernelTest, Test) { - MakeOp(); - AddInputFromArray<int32_t>(TensorShape({6}), {111, 112, 32, 116, 117, 118}); - AddInputFromArray<int64_t>(TensorShape({3}), {0, 4, 6}); - 
TF_ASSERT_OK(RunOpKernel()); - - std::vector<int32_t> expected_values({111, 112, 116, 117, 118}); - std::vector<int64_t> expected_values_inner_splits({0, 2, 3, 5}); - std::vector<int64_t> expected_offset_starts({0, 3, 0}); - std::vector<int64_t> expected_offset_limits({2, 4, 2}); - std::vector<int64_t> output_outer_splits({0, 2, 3}); - EXPECT_THAT(*GetOutput(0), VectorEq(expected_values)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_values_inner_splits)); - EXPECT_THAT(*GetOutput(2), VectorEq(expected_offset_starts)); - EXPECT_THAT(*GetOutput(3), VectorEq(expected_offset_limits)); - EXPECT_THAT(*GetOutput(4), VectorEq(output_outer_splits)); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/utf8_binarize.cc b/tensorflow_text/core/kernels/utf8_binarize.cc deleted file mode 100644 index 2bba0de80..000000000 --- a/tensorflow_text/core/kernels/utf8_binarize.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/utf8_binarize.h" -#include <algorithm> -#include <cassert> - -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/utf8.h" - -namespace tensorflow { -namespace text { - -void Utf8Binarize( - absl::string_view input, int word_length, int bits_per_char, - int replacement, /* out */ absl::Span<float> result) { - assert(result.size() == word_length * bits_per_char); - - const int input_size = input.size(); - int string_pos = 0; - int chars = 0; - int result_pos = 0; - while (string_pos < input_size && chars < word_length) { - UChar32 chr; - U8_NEXT(input, string_pos, input_size, chr); - if (chr < 0) { - // Decoding failure. - chr = replacement; - } - int bits = bits_per_char; - while (bits-- != 0) { - result[result_pos++] = (chr & 1) == 1 ? 1.0f : 0.0f; - chr >>= 1; - } - ++chars; - } - - std::fill(result.begin() + result_pos, result.end(), 0.0f); -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/utf8_binarize.h b/tensorflow_text/core/kernels/utf8_binarize.h index 908cf006b..a6e630ded 100644 --- a/tensorflow_text/core/kernels/utf8_binarize.h +++ b/tensorflow_text/core/kernels/utf8_binarize.h @@ -15,21 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_H_ -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -namespace tensorflow { -namespace text { - -// Stores low-endian floating-point bitwise representations of Unicode code -// points of `input` in `result` (`result.size()` is required to be exactly -// `word_length * bits_per_char` - output is padded / truncated accordingly). -// Replacements (for invalid UTF sequences) are represented by the -// `bits_per_char` lowest bits of `replacement`. 
-void Utf8Binarize(absl::string_view input, int word_length, int bits_per_char, - int replacement, /* out */ absl::Span<float> result); - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/utf8_binarize.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_H_ diff --git a/tensorflow_text/core/kernels/utf8_binarize_kernel.h b/tensorflow_text/core/kernels/utf8_binarize_kernel.h index 3dfdcaad5..7d3e8847f 100644 --- a/tensorflow_text/core/kernels/utf8_binarize_kernel.h +++ b/tensorflow_text/core/kernels/utf8_binarize_kernel.h @@ -15,18 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/utf8_binarize_kernel_template.h" - -namespace tensorflow { -namespace text { - -class Utf8BinarizeOpKernel : public tflite::shim::TfOpKernel<Utf8BinarizeOp> { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/utf8_binarize_kernel.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/utf8_binarize_kernel_template.h b/tensorflow_text/core/kernels/utf8_binarize_kernel_template.h index c8304ea6f..c4e921264 100644 --- a/tensorflow_text/core/kernels/utf8_binarize_kernel_template.h +++ b/tensorflow_text/core/kernels/utf8_binarize_kernel_template.h @@ -15,170 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_TEMPLATE_H_ -#include <cstdint> -#include <vector> - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include 
"tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow_text/core/kernels/utf8_binarize.h" - -namespace tensorflow { -namespace text { - -template <tflite::shim::Runtime Rt> -class Utf8BinarizeOp : public tflite::shim::OpKernelShim<Utf8BinarizeOp, Rt> { - private: - enum Inputs { kInputTokens = 0 }; - enum Outputs { kOutputBinarizations = 0 }; - - using typename tflite::shim::OpKernelShim<Utf8BinarizeOp, Rt>::InitContext; - using typename tflite::shim::OpKernelShim<Utf8BinarizeOp, Rt>::InvokeContext; - using typename tflite::shim::OpKernelShim<Utf8BinarizeOp, - Rt>::ShapeInferenceContext; - - public: - Utf8BinarizeOp() = default; - static constexpr char kOpName[] = "TFText>Utf8Binarize"; - static constexpr char kDoc[] = R"doc( - Decode a UTF-8 string into Unicode code points - and return their bitwise little-endian representations - (see the [RetVec paper](https://arxiv.org/abs/2302.09207)). - )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Attrs(); - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context); - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); - - private: - inline static constexpr absl::string_view kMaxCharsAttr = "word_length"; - inline static constexpr absl::string_view kBitsPerCharAttr = "bits_per_char"; - inline static constexpr absl::string_view kReplacementCharAttr = - "replacement_char"; - - int64_t word_length_; - int64_t bits_per_char_; - int64_t replacement_char_; -}; - -template 
<tflite::shim::Runtime Rt> -std::vector<std::string> Utf8BinarizeOp<Rt>::Attrs() { - return {absl::StrCat(kMaxCharsAttr, ": int"), - absl::StrCat(kBitsPerCharAttr, ": int"), - absl::StrCat(kReplacementCharAttr, ": int")}; -} - -template <tflite::shim::Runtime Rt> -std::vector<std::string> Utf8BinarizeOp<Rt>::Inputs() { - return {"input_tokens: string"}; -} - -template <tflite::shim::Runtime Rt> -std::vector<std::string> Utf8BinarizeOp<Rt>::Outputs() { - return {"output_binarizations: float"}; -} - -template <tflite::shim::Runtime Rt> -absl::Status Utf8BinarizeOp<Rt>::Init(InitContext* context) { - // Attrs - SH_RETURN_IF_ERROR( - context->GetAttr(std::string(kMaxCharsAttr), &word_length_)); - SH_RETURN_IF_ERROR( - context->GetAttr(std::string(kBitsPerCharAttr), &bits_per_char_)); - SH_RETURN_IF_ERROR( - context->GetAttr(std::string(kReplacementCharAttr), &replacement_char_)); - - return absl::OkStatus(); -} - -template <tflite::shim::Runtime Rt> -absl::Status Utf8BinarizeOp<Rt>::ShapeInference(ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto input_tokens_shape_status = c->GetInputShape(kInputTokens); - if (!input_tokens_shape_status.ok()) { - return input_tokens_shape_status.status(); - } - const Shape& input_tokens_shape = *input_tokens_shape_status; - - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - if (!input_tokens_shape.Compatible(rank_1_shape)) { - return absl::FailedPreconditionError( - absl::StrCat("Shape must be rank 1: ", input_tokens_shape.ToString())); - } - - int64_t word_length; - SH_RETURN_IF_ERROR( - c->GetAttr(std::string(kMaxCharsAttr), &word_length)); - int64_t bits_per_char; - SH_RETURN_IF_ERROR(c->GetAttr(std::string(kBitsPerCharAttr), &bits_per_char)); - - const int num_tokens = input_tokens_shape.Dim(0); - const int bits_per_token = word_length * bits_per_char; - const Shape output_shape{num_tokens, bits_per_token}; - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputBinarizations, output_shape)); - - return 
absl::OkStatus(); -} - -template <tflite::shim::Runtime Rt> -absl::Status Utf8BinarizeOp<Rt>::Invoke(InvokeContext* context) { - // Attrs - const int word_length = word_length_; - const int bits_per_char = bits_per_char_; - const int replacement_char = replacement_char_; - const int bits_per_token = word_length * bits_per_char; - - // Inputs - const auto tokens_statusor = context->GetInput(kInputTokens); - if (!tokens_statusor.ok()) { - return tokens_statusor.status(); - } - const auto tokens = (*tokens_statusor)->template As<tensorflow::tstring, 1>(); - const int num_tokens = tokens.Dim(0); - - // Outputs - auto binarizations_statusor = - context->GetOutput(kOutputBinarizations, {num_tokens, bits_per_token}); - if (!binarizations_statusor.ok()) { - return binarizations_statusor.status(); - } - auto binarizations = (*binarizations_statusor)->template As<float, 2>(); - - // Iterate through all the token strings and binarize them. - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - float* row_start = &binarizations(token_idx, 0); - absl::Span<float> output_binarization(row_start, bits_per_token); - Utf8Binarize(tokens(token_idx), - /*word_length=*/word_length, - /*bits_per_char=*/bits_per_char, - /*replacement=*/replacement_char, - /*result=*/output_binarization); - } - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/utf8_binarize_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/utf8_binarize_test.cc b/tensorflow_text/core/kernels/utf8_binarize_test.cc deleted file mode 100644 index 9f61896c9..000000000 --- a/tensorflow_text/core/kernels/utf8_binarize_test.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/utf8_binarize.h" -#include <vector> - -#include <gmock/gmock.h> -#include "absl/types/span.h" - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::ElementsAre; - -TEST(UnicodeTest, Utf8Binarize) { - std::vector<float> out1(3 * 4); - Utf8Binarize("hello", /*word_length=*/3, /*bits_per_char=*/4, - /*replacement=*/3, /*result=*/absl::MakeSpan(out1)); - // L-endian 4 lowest bits of: - EXPECT_THAT(out1, ElementsAre(0, 0, 0, 1, // "h" - 1, 0, 1, 0, // "e" - 0, 0, 1, 1)); // "l" - - std::vector<float> out2(4 * 5); - Utf8Binarize("爱上一个不回", /*word_length=*/4, /*bits_per_char=*/5, - /*replacement=*/7, /*result=*/absl::MakeSpan(out2)); - // L-endian 5 lowest bits of: - EXPECT_THAT(out2, ElementsAre(1, 0, 0, 0, 1, // "爱" - 0, 1, 0, 1, 0, // "上" - 0, 0, 0, 0, 0, // "一" - 0, 1, 0, 1, 0)); // "个" - - // Notable example: - // - (Unicode) characters are padded, not truncated as above (zero-padding); - // - the UTF-8 sequence is invalid, so we get a replacement bit pattern. - std::vector<float> out3(3 * 6); - Utf8Binarize("\xc3(", /*word_length=*/3, /*bits_per_char=*/6, - /*replacement=*/35, /*result=*/absl::MakeSpan(out3)); - // LE 6 lowest bits of: - EXPECT_THAT(out3, ElementsAre(1, 1, 0, 0, 0, 1, // Replacement. - 0, 0, 0, 1, 0, 1, // "(". - 0, 0, 0, 0, 0, 0)); // Padding. 
-} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/utf8_binarize_tflite.cc b/tensorflow_text/core/kernels/utf8_binarize_tflite.cc deleted file mode 100644 index 9bb6ac3c5..000000000 --- a/tensorflow_text/core/kernels/utf8_binarize_tflite.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/utf8_binarize_tflite.h" - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/utf8_binarize_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddUtf8Binarize(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel<tensorflow::text::Utf8BinarizeOp>::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/utf8_binarize_tflite.h b/tensorflow_text/core/kernels/utf8_binarize_tflite.h index b4e145d1e..c34028803 100644 --- a/tensorflow_text/core/kernels/utf8_binarize_tflite.h +++ b/tensorflow_text/core/kernels/utf8_binarize_tflite.h @@ -12,22 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_KERNELS_UTF8_BINARIZE_TFLITE_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_KERNELS_UTF8_BINARIZE_TFLITE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_TFLITE_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_TFLITE_H_ -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/core/kernels/text/utf8_binarize_tflite.h" -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddUtf8Binarize(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_GOOGLE_KERNELS_UTF8_BINARIZE_TFLITE_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_UTF8_BINARIZE_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/whitespace_tokenize_kernel.cc b/tensorflow_text/core/kernels/whitespace_tokenize_kernel.cc deleted file mode 100644 index dcdac0c5f..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenize_kernel.cc +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include <string.h> - -#include <vector> - -#include "icu4c/source/common/unicode/uchar.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/lookup_interface.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace text { - -template <typename SPLITS_TYPE> -class WhitespaceTokenizeWithOffsetsOp : public OpKernel { - public: - explicit WhitespaceTokenizeWithOffsetsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - /** - * Breaks a series of codepoints into individual groups based on the script - * code. - * - * We gain a dimension while tokenizing since a series of integer codepoints - * is tokenized into different codepoint groups. - * - * This accepts two input tensors: a rank 1 tensor of codepoint values and - * a single rank 1 tensor of splits which determine where each string begins - * and ends from the provided codepoints. 
- */ - void Compute(OpKernelContext* context) override { - // Get inputs - const Tensor& input_values_tensor = context->input(0); - const auto input_values_flat = input_values_tensor.flat<int32>(); - const Tensor& input_splits_tensor = context->input(1); - const auto input_splits_flat = input_splits_tensor.flat<SPLITS_TYPE>(); - - // Since we limit to a 2-D input (flat_values of rank 1 and a single splits - // tensor), our output dimension will always be 3-D (flat_values of rank 1 - // with two splits - inner for the tokenized values and the outer for those - // grouped by the original strings). - // A few things to note: - // 1) The values and inner splits of the tokenized strings have an unknown - // length, as well as the offsets, so we allocate them at the end. - // 2) The outer splits of the tokenized strings matches that of the offset - // splits. Thus, we will only return one set and use it for all of them. - // 3) The outer splits shape will match the original input_splits. - Tensor* output_outer_splits_tensor; - OP_REQUIRES_OK(context, - context->allocate_output("output_outer_splits", - input_splits_tensor.shape(), - &output_outer_splits_tensor)); - auto output_outer_splits_flat = - output_outer_splits_tensor->flat<SPLITS_TYPE>(); - - std::vector<int32> output_values; - std::vector<SPLITS_TYPE> output_values_inner_splits; - std::vector<int64> output_offset_starts; - std::vector<int64> output_offset_limits; - - // Loop over the codepoints (a split at a time) and create splits of tokens. 
- for (int splits_idx = 0; splits_idx < input_splits_flat.size() - 1; - splits_idx++) { - output_outer_splits_flat(splits_idx) = output_offset_starts.size(); - bool token_has_start_set = false; - int32 curr_skipped_spaces = 0; // Used when computing the end of a token - const int curr_word_start_idx = input_splits_flat(splits_idx); - for (int values_idx = curr_word_start_idx; - values_idx < input_splits_flat(splits_idx + 1); values_idx++) { - // End current token if we find whitespace - if (u_isUWhiteSpace(input_values_flat(values_idx))) { - if (token_has_start_set) { - output_offset_limits.push_back(values_idx - curr_word_start_idx - - curr_skipped_spaces); - } - token_has_start_set = false; - ++curr_skipped_spaces; - } else { - // Non whitespace. Start a new token if needed, and append the - // codepoint to our current token. - if (!token_has_start_set) { - // Set token start offset relative to current string. - output_offset_starts.push_back(values_idx - curr_word_start_idx); - // Set split to indicate start of a new token. - output_values_inner_splits.push_back(output_values.size()); - token_has_start_set = true; - } - output_values.push_back(input_values_flat(values_idx)); - curr_skipped_spaces = 0; - } - } - // Looping through the codepoints for current tokens complete. Now set the - // last limit of out last token (if we found a start earlier). - if (token_has_start_set) { - output_offset_limits.push_back(input_splits_flat(splits_idx + 1) - - curr_word_start_idx - - curr_skipped_spaces); - } - } - // Now set the closing value of our splits. - output_outer_splits_flat(input_splits_flat.size() - 1) = - output_offset_starts.size(); - output_values_inner_splits.push_back(output_values.size()); - -// Allocate output & fill output tensors. 
-#define DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(name, dtype) \ - int64 name##_size = name.size(); \ - Tensor* name##_tensor = nullptr; \ - OP_REQUIRES_OK(context, \ - context->allocate_output(#name, TensorShape({name##_size}), \ - &name##_tensor)); \ - auto name##_data = name##_tensor->flat<dtype>().data(); \ - memcpy(name##_data, name.data(), name##_size * sizeof(dtype)); - - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_values, int32); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_values_inner_splits, - SPLITS_TYPE); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_offset_starts, int64); - DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR(output_offset_limits, int64); - -#undef DECLARE_ALLOCATE_AND_FILL_OUTPUT_TENSOR - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(WhitespaceTokenizeWithOffsetsOp); -}; - -REGISTER_KERNEL_BUILDER(Name("WhitespaceTokenizeWithOffsets") - .Device(DEVICE_CPU) - .TypeConstraint<int32>("Tsplits"), - WhitespaceTokenizeWithOffsetsOp<int32>); -REGISTER_KERNEL_BUILDER(Name("WhitespaceTokenizeWithOffsets") - .Device(DEVICE_CPU) - .TypeConstraint<int64>("Tsplits"), - WhitespaceTokenizeWithOffsetsOp<int64>); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenize_kernel_test.cc b/tensorflow_text/core/kernels/whitespace_tokenize_kernel_test.cc deleted file mode 100644 index d9792be45..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenize_kernel_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include <vector> - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow_text/core/kernels/text_kernels_test_util.h" - -namespace tensorflow { -namespace text { - -using tensorflow::FakeInput; -using tensorflow::NodeDefBuilder; -using tensorflow::Status; -using tensorflow::TensorShape; -using tensorflow::text_kernels_test_util::VectorEq; - -class WhitespaceTokenizeWithOffsetsKernelTest - : public tensorflow::OpsTestBase { - public: - void MakeOp() { - TF_ASSERT_OK(NodeDefBuilder("tested_op", "WhitespaceTokenizeWithOffsets") - .Input(FakeInput()) - .Input(FakeInput()) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - } -}; - -TEST_F(WhitespaceTokenizeWithOffsetsKernelTest, Test) { - MakeOp(); - AddInputFromArray<int32_t>(TensorShape({6}), {111, 112, 32, 116, 117, 118}); - AddInputFromArray<int64_t>(TensorShape({3}), {0, 4, 6}); - TF_ASSERT_OK(RunOpKernel()); - - std::vector<int32_t> expected_values({111, 112, 116, 117, 118}); - std::vector<int64_t> expected_values_inner_splits({0, 2, 3, 5}); - std::vector<int64_t> expected_offset_starts({0, 3, 0}); - std::vector<int64_t> expected_offset_limits({2, 4, 2}); - std::vector<int64_t> output_outer_splits({0, 2, 3}); - EXPECT_THAT(*GetOutput(0), VectorEq(expected_values)); - EXPECT_THAT(*GetOutput(1), VectorEq(expected_values_inner_splits)); - EXPECT_THAT(*GetOutput(2), VectorEq(expected_offset_starts)); - EXPECT_THAT(*GetOutput(3), VectorEq(expected_offset_limits)); - EXPECT_THAT(*GetOutput(4), VectorEq(output_outer_splits)); -} - -} // 
namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer.cc b/tensorflow_text/core/kernels/whitespace_tokenizer.cc deleted file mode 100644 index dfe7107fc..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/whitespace_tokenizer.h" - -#include <string> -#include <vector> - -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/appendable.h" -#include "icu4c/source/common/unicode/schriter.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/ucnv.h" -#include "icu4c/source/common/unicode/ucnv_err.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/uniset.h" -#include "icu4c/source/common/unicode/unistr.h" -#include "icu4c/source/common/unicode/uset.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "icu4c/source/common/unicode/bytestream.h" -#include "icu4c/source/common/unicode/edits.h" -#include "icu4c/source/common/unicode/normalizer2.h" -#include "icu4c/source/common/unicode/stringoptions.h" -#include "icu4c/source/common/unicode/stringpiece.h" -#include "icu4c/source/common/unicode/utf.h" -#include "icu4c/source/common/unicode/utf8.h" - - -namespace tensorflow { -namespace text { - -void WhitespaceTokenizer::Tokenize(const 
absl::string_view input, - std::vector<std::string>* tokens) { - std::vector<int> start_offsets, end_offsets; - Tokenize(input, tokens, &start_offsets, &end_offsets); -} - -void WhitespaceTokenizer::Tokenize(const absl::string_view input, - std::vector<std::string>* tokens, - std::vector<int>* start_offsets, - std::vector<int>* end_offsets) { - const int input_size = input.size(); - int position = 0, prev_position = 0; - UChar32 codepoint; - bool inside_token = false; - while (position < input_size) { - prev_position = position; - U8_NEXT(input, position, input_size, codepoint); - if (config_.IsWhitespace(codepoint)) { - if (inside_token) { - int end_pos = position - 1; - end_offsets->push_back(end_pos); - int start_pos = start_offsets->back(); - std::string token(input.substr(start_pos, end_pos - start_pos)); - tokens->push_back(token); - inside_token = false; - } - } else { - if (!inside_token) { - start_offsets->push_back(prev_position); - inside_token = true; - } - } - } - // save final word - if (inside_token) { - int end_pos = position; - end_offsets->push_back(end_pos); - int start_pos = start_offsets->back(); - std::string token(input.substr(start_pos, end_pos - start_pos)); - tokens->push_back(token); - } -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer.h b/tensorflow_text/core/kernels/whitespace_tokenizer.h index b2b357500..21e776938 100644 --- a/tensorflow_text/core/kernels/whitespace_tokenizer.h +++ b/tensorflow_text/core/kernels/whitespace_tokenizer.h @@ -15,100 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_H_ -#include <string> -#include <vector> - -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/umachine.h" - -namespace tensorflow { -namespace text { - -// Helper class for working with the WhitespaceaTokenizer config. 
The -// config is essentially a bit array stored in characters, where each bit in -// the char represents a Unicode character and whether or not it is considered -// as whitespace. -// -// This bit array contains all codepoints up to the largest whitespace -// character. So any codepoint larger than the array is not whitespace, and -// a lookup is simply using the codepoint value as the index. The first 3 bits -// of the codepoint indicate which bit in a character is the value located, and -// using the rest of the bits of the codepoint we can determine which -// character the particular codepoint is located at. -class WhitespaceTokenizerConfig { - public: - // This object does not own the config, so make certain it exists for the - // lifetime of the class. - WhitespaceTokenizerConfig(const absl::string_view config) - : config_(config), max_codepoint_(config.length() * 8) {} - WhitespaceTokenizerConfig(const std::string* config) - : config_(*config), max_codepoint_(config->length() * 8) {} - - inline bool IsWhitespace(const UChar32 codepoint) const { - return codepoint != U_SENTINEL && - codepoint < max_codepoint_ && - config_[codepoint >> 3] & (1 << (char)(codepoint & 0x7)); - } - - private: - const absl::string_view config_; - const int max_codepoint_; -}; - -class WhitespaceTokenizer { - public: - // Creates an instance. - // - // Args: - // * config: A WhitespaceTokenizerConfig which should be created using the - // WhitespaceTokenizerConfigBuilder - WhitespaceTokenizer(const WhitespaceTokenizerConfig& cfg) - : config_(cfg) { } - - // Tokenizes a string (or series of character codepoints) by whitespace. - // - // Example: - // input = "Show me the way." - // tokens = ["Show", "me", "the", "way."] - // start_offsets = [0, 5, 8, 12] - // end_offsets = [4, 7, 11, 16] - // - // The input should be UTF-8 but the tokenization is performed on Unicode - // codepoints. - // - // Args: - // * input: The UTF-8 string of an input. - // * tokens: The output tokens. 
- // * start_offsets: The start offsets of output tokens in the input - // text, in utf-8 bytes. - // * end_offsets: The end offsets of output tokens in the input - // text, in utf-8 bytes. - // Note: the start offsets are inclusive and the end offsets are exclusive. - void Tokenize(const absl::string_view input, - std::vector<std::string>* tokens, - std::vector<int>* start_offsets, - std::vector<int>* end_offsets); - - // Tokenizes a string (or series of character codepoints) by whitespace. - // - // Example: - // input = "Show me the way." - // output = ["Show", "me", "the", "way."] - // - // The input should be UTF-8 but the tokenization is performed on Unicode - // codepoints. - // - // Args: - // * input: The UTF-8 string of an input. - // * tokens: The output tokens. - void Tokenize(const absl::string_view input, - std::vector<std::string>* tokens); - - private: - const WhitespaceTokenizerConfig config_; -}; - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/whitespace_tokenizer.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_H_ diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.cc b/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.cc deleted file mode 100644 index db4a17063..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" - -#include <string> - -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/uniset.h" -#include "icu4c/source/common/unicode/uset.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "icu4c/source/common/unicode/utypes.h" - -namespace tensorflow { -namespace text { - -namespace { - -const icu::UnicodeSet& WhiteSpaceSet() { - // Will not fail because the data is hardcoded in the ICU library. - UErrorCode error_code = U_ZERO_ERROR; - const USet* c_set = u_getBinaryPropertySet(UCHAR_WHITE_SPACE, &error_code); - // assert(U_SUCCESS(error_code)); - const icu::UnicodeSet* set = icu::UnicodeSet::fromUSet(c_set); - return *set; -} - -} // namespace - -std::string BuildWhitespaceString() { - std::string str; - char buf[U8_MAX_LENGTH]; - for (auto cp : WhiteSpaceSet().codePoints()) { - int len = 0; - U8_APPEND_UNSAFE(buf, len, cp); - str.append(buf, len); - } - return str; -} - -std::string BuildWhitespaceTokenizerConfig() { - const icu::UnicodeSet& set = WhiteSpaceSet(); - int range_count = set.getRangeCount(); - UChar32 largest_whitespace = set.getRangeEnd(range_count - 1); - // The string will hold our bit array - std::string bitset((largest_whitespace >> 3) + 1, 0); - for (auto cp : set.codePoints()) { - int index = cp >> 3; - bitset[index] |= 1 << (cp & 7); - } - return bitset; -} - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h b/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h index 60a3ca092..e11425fff 100644 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h +++ b/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h @@ -15,30 +15,6 @@ #ifndef 
THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_CONFIG_BUILDER_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_CONFIG_BUILDER_H_ -#include <string> - - -namespace tensorflow { -namespace text { - -// Builds a WhitespaceTokenizer config object. This contains the Unicode -// codepoints which are considered whitespaces. -// -// The config object is a series of bytes, where each bit represents a Unicode -// character and is 1 if it is a whitespace character, and 0 otherwise. -// -// Returns: -// The bytes of the config as a string. -std::string BuildWhitespaceTokenizerConfig(); - -// Builds a string full of all the whitespace characters. It is mainly used -// for testing and validation. -// -// Returns: -// A string of Unicode whitespace characters. -std::string BuildWhitespaceString(); - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/whitespace_tokenizer_config_builder.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_CONFIG_BUILDER_H_ diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder_test.cc b/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder_test.cc deleted file mode 100644 index 9c8a2724b..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_config_builder_test.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" - -#include <string> - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "icu4c/source/common/unicode/appendable.h" -#include "icu4c/source/common/unicode/bytestream.h" -#include "icu4c/source/common/unicode/edits.h" -#include "icu4c/source/common/unicode/normalizer2.h" -#include "icu4c/source/common/unicode/schriter.h" -#include "icu4c/source/common/unicode/stringoptions.h" -#include "icu4c/source/common/unicode/stringpiece.h" -#include "icu4c/source/common/unicode/uchar.h" -#include "icu4c/source/common/unicode/ucnv.h" -#include "icu4c/source/common/unicode/ucnv_err.h" -#include "icu4c/source/common/unicode/umachine.h" -#include "icu4c/source/common/unicode/uniset.h" -#include "icu4c/source/common/unicode/unistr.h" -#include "icu4c/source/common/unicode/uset.h" -#include "icu4c/source/common/unicode/utf.h" -#include "icu4c/source/common/unicode/utf8.h" -#include "icu4c/source/common/unicode/utypes.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer.h" - -namespace tensorflow { -namespace text { -namespace { - -TEST(WhitespaceTokenizerConfigBuilderTest, BuildWhitespaceString) { - std::string result = BuildWhitespaceString(); - EXPECT_THAT(result, ::testing::HasSubstr(" ")); - EXPECT_THAT(result, ::testing::HasSubstr("\n")); -} - -TEST(WhitespaceTokenizerConfigBuilderTest, - BuildWhitespaceTokenizerConfig_AllWhitespacePresent) { - std::string whitespaces = BuildWhitespaceString(); - icu::UnicodeString codepoints = icu::UnicodeString::fromUTF8(whitespaces); - std::string config = BuildWhitespaceTokenizerConfig(); - // verify all whitepaces are present - WhitespaceTokenizerConfig cfg(config); - for (int i = 0; i < codepoints.length(); ++i) { - EXPECT_TRUE(cfg.IsWhitespace(codepoints[i])); - } -} - -TEST(WhitespaceTokenizerConfigBuilderTest, - BuildWhitespaceTokenizerConfig_MinSize) { - std::string whitespaces = 
BuildWhitespaceString(); - icu::UnicodeString codepoints = icu::UnicodeString::fromUTF8(whitespaces); - std::string config = BuildWhitespaceTokenizerConfig(); - // verify we are the minimum perfect hash - auto largest_cp = codepoints[codepoints.length() - 1]; - EXPECT_EQ(config.length(), (largest_cp / 8) + 1); -} - -TEST(WhitespaceTokenizerConfigBuilderTest, - BuildWhitespaceTokenizerConfig_VerifyCount) { - std::string whitespaces = BuildWhitespaceString(); - icu::UnicodeString codepoints = icu::UnicodeString::fromUTF8(whitespaces); - std::string config = BuildWhitespaceTokenizerConfig(); - // verify we have the correct number of true values (rest will be false) - int count = 0; - WhitespaceTokenizerConfig cfg(config); - for (int i = 0; i < config.length() * 8; ++i) { - count += cfg.IsWhitespace(i) ? 1 : 0; - } - EXPECT_EQ(count, codepoints.length()); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.cc b/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.cc deleted file mode 100644 index 78a90c02b..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/whitespace_tokenizer_kernel.h" - -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace text { - -REGISTER_KERNEL_BUILDER(Name(WhitespaceTokenizeWithOffsetsV2OpKernel::OpName()) - .Device(tensorflow::DEVICE_CPU), - WhitespaceTokenizeWithOffsetsV2OpKernel); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.h b/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.h index 00eae4b3f..97ce10b9c 100644 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.h +++ b/tensorflow_text/core/kernels/whitespace_tokenizer_kernel.h @@ -12,22 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZE_KERNEL_H_ -#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZE_KERNEL_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_H_ -#include "tensorflow/lite/kernels/shim/tf_op_shim.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h" +#include "tensorflow/core/kernels/text/whitespace_tokenizer_kernel.h" -namespace tensorflow { -namespace text { - -class WhitespaceTokenizeWithOffsetsV2OpKernel - : public tflite::shim::TfOpKernel<WhitespaceTokenizeWithOffsetsV2Op> { - public: - using TfOpKernel::TfOpKernel; -}; - -} // namespace text -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZE_KERNEL_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_H_ diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h index e05a914bb..1682fee2e 100644 --- 
a/tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h +++ b/tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h @@ -15,180 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_TEMPLATE_H_ -#include <iostream> -#include <vector> - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/lite/kernels/shim/op_kernel.h" -#include "tensorflow/lite/kernels/shim/shape.h" -#include "tensorflow/lite/kernels/shim/status_macros.h" -#include "tensorflow/lite/kernels/shim/tensor_view.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer.h" - -namespace tensorflow { -namespace text { - -template <tflite::shim::Runtime Rt> -class WhitespaceTokenizeWithOffsetsV2Op - : public tflite::shim::OpKernelShim<WhitespaceTokenizeWithOffsetsV2Op, Rt> { - private: - enum Inputs { - kInputValues = 0, - kInputConfig - }; - enum Outputs { - kOutputTokens = 0, - kOutputRowSplits, - kOutputStartOffsets, - kOutputEndOffsets - }; - - using typename tflite::shim::OpKernelShim<WhitespaceTokenizeWithOffsetsV2Op, - Rt>::InitContext; - using typename tflite::shim::OpKernelShim<WhitespaceTokenizeWithOffsetsV2Op, - Rt>::InvokeContext; - using typename tflite::shim::OpKernelShim<WhitespaceTokenizeWithOffsetsV2Op, - Rt>::ShapeInferenceContext; - - public: - WhitespaceTokenizeWithOffsetsV2Op() = default; - static constexpr char kOpName[] = "TFText>WhitespaceTokenizeWithOffsetsV2"; - static constexpr char kDoc[] = R"doc( - Splits a string into tokens based off of Unicode whitespaces. It also returns - the relative byte offsets for each token. 
- - ### Example: - - ```python - >>> splitter = WhitespaceTokenizer() - >>> tokens, starts, ends = splitter.tokenize_with_offsets("a bb ccc") - >>> print(tokens.numpy(), starts.numpy(), ends.numpy()) - [b'a' b'bb' b'ccc'] [0 2 5] [1 4 8] - ``` - - Args: - input_values: 1D Tensor of strings to tokenize. - input_config: A string representing a WhitespaceTokenizerConfig. - - Returns: - * output_tokens: 1D tensor containing the tokens for all input strings. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_row_splits: 1D int tensor with the row splits that allow us to - build RaggedTensors from output_tokens, output_start_offsets, and - output_end_offsets. - * output_start_offsets: 1D tensor containing the inclusive start byte offset - for each token in all input strings. Corresponds 1:1 with output_tokens. - A 2D RaggedTensor can be constructed from this and output_row_splits. - * output_end_offsets: 1D tensor containing the exclusive end byte offset for - each token in all input strings. Corresponds 1:1 with output_tokens. - A 2D RaggedTensor can be constructed from this and output_row_splits. 
- )doc"; - - static const char* OpName() { return kOpName; } - static const char* Doc() { return kDoc; } - - // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Attrs() { return {}; } - - // Inputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Inputs(); - - // Outputs declaration (syntax: https://www.tensorflow.org/guide/create_op) - static std::vector<std::string> Outputs(); - - // Initializes the op - absl::Status Init(InitContext* context) { return absl::OkStatus(); } - - // Runs the operation - absl::Status Invoke(InvokeContext* context); - - // Shape inference - static absl::Status ShapeInference(ShapeInferenceContext* c); -}; - -template <tflite::shim::Runtime Rt> -std::vector<std::string> WhitespaceTokenizeWithOffsetsV2Op<Rt>::Inputs() { - return {"input_values: string", "input_config: string"}; -} - -template <tflite::shim::Runtime Rt> -std::vector<std::string> WhitespaceTokenizeWithOffsetsV2Op<Rt>::Outputs() { - return {"output_tokens: string", "output_row_splits: int64", - "output_start_offsets: int32", "output_end_offsets: int32"}; -} - -template <tflite::shim::Runtime Rt> -absl::Status WhitespaceTokenizeWithOffsetsV2Op<Rt>::ShapeInference( - ShapeInferenceContext* c) { - using tflite::shim::Shape; - const auto input_values_shape_status = c->GetInputShape(kInputValues); - if (!input_values_shape_status.ok()) { - return input_values_shape_status.status(); - } - const Shape& input_values_shape = *input_values_shape_status; - - const auto rank_1_shape = Shape({Shape::kUnknownDim}); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputTokens, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputStartOffsets, rank_1_shape)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputEndOffsets, rank_1_shape)); - const int num_splits = Shape::AddDims(1, input_values_shape.Dim(0)); - SH_RETURN_IF_ERROR(c->SetOutputShape(kOutputRowSplits, Shape({num_splits}))); 
- - return absl::OkStatus(); -} - -template <tflite::shim::Runtime Rt> - absl::Status WhitespaceTokenizeWithOffsetsV2Op<Rt> - ::Invoke(InvokeContext* context) { - // Inputs - const auto values_statusor = context->GetInput(kInputValues); - if (!values_statusor.ok()) { - return values_statusor.status(); - } - const auto values = (*values_statusor)->template As<tensorflow::tstring, 1>(); - - const auto cfg_statusor = context->GetInput(kInputConfig); - if (!cfg_statusor.ok()) { - return cfg_statusor.status(); - } - const absl::string_view config = - (*cfg_statusor)->template AsScalar<tensorflow::tstring>(); - WhitespaceTokenizer tokenizer(config); - - // Outputs - std::vector<std::string> tokens; - std::vector<int64_t> row_splits; - std::vector<int32_t> start_offsets; - std::vector<int32_t> end_offsets; - - // Iterate through all the values and wordpiece tokenize them. - row_splits.push_back(0); - for (int i = 0; i < values.Dim(0); ++i) { - // Tokenize into subwords and record the offset locations. - const int orig_num_tokens = tokens.size(); - tokenizer.Tokenize(values(i), &tokens, &start_offsets, &end_offsets); - const int delta_num_tokens = tokens.size() - orig_num_tokens; - // Record the row splits. - row_splits.push_back(delta_num_tokens + row_splits.back()); - } - - // Allocate output & fill output tensors. 
- SH_RETURN_IF_ERROR(this->template FillOutputTensor<std::string, - tensorflow::tstring>( - tokens, kOutputTokens, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int64_t, int64_t>( - row_splits, kOutputRowSplits, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int32_t, int32_t>( - start_offsets, kOutputStartOffsets, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor<int32_t, int32_t>( - end_offsets, kOutputEndOffsets, context)); - - return absl::OkStatus(); -} - -} // namespace text -} // namespace tensorflow +#include "tensorflow/core/kernels/text/whitespace_tokenizer_kernel_template.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_KERNEL_TEMPLATE_H_ diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_test.cc b/tensorflow_text/core/kernels/whitespace_tokenizer_test.cc deleted file mode 100644 index aa94839f7..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_test.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/whitespace_tokenizer.h" - -#include <gmock/gmock.h> -#include <gtest/gtest.h> -#include "absl/flags/flag.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" - -namespace tensorflow { -namespace text { -namespace { - -using ::testing::ElementsAre; - -TEST(WhitespaceTokenizerTest, TokenizeWithOffsets) { - absl::string_view input("I heard the news today"); - std::vector<std::string> output_tokens; - std::vector<int> output_start_offsets; - std::vector<int> output_end_offsets; - std::string config(BuildWhitespaceTokenizerConfig()); - WhitespaceTokenizer t(&config); - t.Tokenize(input, &output_tokens, &output_start_offsets, &output_end_offsets); - EXPECT_THAT(output_tokens, ElementsAre("I", "heard", "the", "news", "today")); - EXPECT_THAT(output_start_offsets, ElementsAre(0, 2, 8, 12, 17)); - EXPECT_THAT(output_end_offsets, ElementsAre(1, 7, 11, 16, 22)); -} - -TEST(WhitespaceTokenizerTest, Tokenize) { - absl::string_view input("I heard the news today"); - std::vector<std::string> output_tokens; - std::string config = BuildWhitespaceTokenizerConfig(); - WhitespaceTokenizer t(&config); - t.Tokenize(input, &output_tokens); - EXPECT_THAT(output_tokens, ElementsAre("I", "heard", "the", "news", "today")); -} - -TEST(WhitespaceTokenizerTest, Internationalization) { - absl::string_view input("la灯 灯a 瀮b"); - std::vector<std::string> output_tokens; - std::vector<int> output_start_offsets; - std::vector<int> output_end_offsets; - std::string config = BuildWhitespaceTokenizerConfig(); - WhitespaceTokenizer t(&config); - t.Tokenize(input, &output_tokens, &output_start_offsets, &output_end_offsets); - EXPECT_THAT(output_start_offsets, ElementsAre(0, 6, 11)); - EXPECT_THAT(output_end_offsets, ElementsAre(5, 10, 15)); -} - -TEST(WhitespaceTokenizerTest, InvalidCodepoint) { - absl::string_view 
input("\xE3"); - std::vector<std::string> output_tokens; - std::vector<int> output_start_offsets; - std::vector<int> output_end_offsets; - std::string config = BuildWhitespaceTokenizerConfig(); - WhitespaceTokenizer t(&config); - t.Tokenize(input, &output_tokens, &output_start_offsets, &output_end_offsets); - EXPECT_THAT(output_start_offsets, ElementsAre(0)); - EXPECT_THAT(output_end_offsets, ElementsAre(1)); -} - -TEST(WhitespaceTokenizerTest, MaxCodepoint) { - // Create an artificially-small config so that we can test behavior with - // codepoints at the upper edge of its range. This bitmap marks 0x00-0x3f as - // whitespace. - std::string config(8, '\xff'); - // Verify that reading one bit off the end of the bitmap returns - // not-whitespace. - WhitespaceTokenizerConfig cfg(config); - EXPECT_FALSE(cfg.IsWhitespace(0x40)); -} - -} // namespace -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.cc b/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.cc deleted file mode 100644 index 58c09915a..000000000 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h" - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/shim/tflite_op_shim.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_kernel_template.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddWhitespaceTokenize(tflite::MutableOpResolver* resolver) { - tflite::shim::TfLiteOpKernel< - tensorflow::text::WhitespaceTokenizeWithOffsetsV2Op>::Add(resolver); -} - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite diff --git a/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h b/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h index 85b003e00..e3b5aeae9 100644 --- a/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h +++ b/tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h @@ -15,19 +15,6 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_TFLITE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_TFLITE_H_ -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace tflite { -namespace ops { -namespace custom { -namespace text { - -extern "C" void AddWhitespaceTokenize(::tflite::MutableOpResolver* resolver); - -} // namespace text -} // namespace custom -} // namespace ops -} // namespace tflite +#include "tensorflow/core/kernels/text/whitespace_tokenizer_tflite.h" #endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WHITESPACE_TOKENIZER_TFLITE_H_ diff --git a/tensorflow_text/core/kernels/wordpiece_kernel.cc b/tensorflow_text/core/kernels/wordpiece_kernel.cc deleted file mode 100644 index 8863d80ab..000000000 --- a/tensorflow_text/core/kernels/wordpiece_kernel.cc +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <limits> -#include <memory> -#include <string> -#include <vector> - -#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h" -#include "tensorflow/core/framework/lookup_interface.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" -#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" - -namespace tensorflow { -namespace text { - -namespace { -string GetWordSplitChar(OpKernelConstruction* ctx) { - string suffix_indicator; - ([=](string* c) -> void { - OP_REQUIRES_OK(ctx, ctx->GetAttr("suffix_indicator", c)); - })(&suffix_indicator); - return suffix_indicator; -} - -int32 GetMaxCharsPerWord(OpKernelConstruction* ctx) { - int32 max_chars_per_word; - ([=](int32* c) -> void { - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_bytes_per_word", c)); - })(&max_chars_per_word); - return max_chars_per_word; -} - -int32 GetMaxCharsPerToken(OpKernelConstruction* ctx) { - int32 max_chars_per_token; - ([=](int32* c) -> void { - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_chars_per_token", c)); - })(&max_chars_per_token); - return max_chars_per_token; -} - -bool GetShouldUseUnknownToken(OpKernelConstruction* ctx) { - bool use_unknown_token; - ([=](bool* c) -> void { - 
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_unknown_token", c)); - })(&use_unknown_token); - return use_unknown_token; -} - -string GetUnknownToken(OpKernelConstruction* ctx) { - string unknown_token; - ([=](string* c) -> void { - OP_REQUIRES_OK(ctx, ctx->GetAttr("unknown_token", c)); - })(&unknown_token); - return unknown_token; -} - -bool GetSplitUnknownCharacters(OpKernelConstruction* ctx) { - bool split_unknown_characters; - ([=](bool* c) -> void { - OP_REQUIRES_OK(ctx, ctx->GetAttr("split_unknown_characters", c)); - })(&split_unknown_characters); - return split_unknown_characters; -} - -Status GetTableHandle(const string& input_name, OpKernelContext* ctx, - string* container, string* table_handle) { - { - mutex* mu; - TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); - mutex_lock l(*mu); - Tensor tensor; - TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); - if (tensor.NumElements() != 2) { - return errors::InvalidArgument( - "Lookup table handle must be scalar, but had shape: ", - tensor.shape().DebugString()); - } - auto h = tensor.flat<tstring>(); - *container = h(0); - *table_handle = h(1); - } - return absl::OkStatus(); -} - -// Gets the LookupTable stored in the ctx->resource_manager() with key -// passed by attribute with name input_name, returns null if the table -// doesn't exist. 
-Status GetLookupTable(const string& input_name, OpKernelContext* ctx, - lookup::LookupInterface** table) { - string container; - string table_handle; - DataType handle_dtype; - TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); - if (handle_dtype == DT_RESOURCE) { - ResourceHandle handle; - TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); - return LookupResource(ctx, handle, table); - } else { - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); - return ctx->resource_manager()->Lookup(container, table_handle, table); - } -} - -class LookupTableVocab : public WordpieceVocab { - public: - LookupTableVocab(lookup::LookupInterface* table, OpKernelContext* ctx); - - virtual LookupStatus Contains(const absl::string_view key, bool* value) const; - - private: - // not owned - mutable lookup::LookupInterface* table_; - OpKernelContext* ctx_; - Tensor default_value_; -}; - -Status ToStatus(const LookupStatus& status) { - if (status.success) { - return absl::OkStatus(); - } - - return errors::InvalidArgument(status.error_msg); -} - -constexpr int64 kOutOfVocabValue = -1; - -LookupTableVocab::LookupTableVocab(lookup::LookupInterface* table, - OpKernelContext* ctx) - : table_(table), ctx_(ctx), default_value_(DT_INT64, TensorShape({1})) { - default_value_.flat<int64>()(0) = kOutOfVocabValue; -} - -LookupStatus LookupTableVocab::Contains(const absl::string_view key, - bool* value) const { - if (value == nullptr) { - return LookupStatus("Bad 'value' param."); - } - Tensor keys(DT_STRING, TensorShape({1})); - keys.flat<tstring>()(0) = tstring(key.data(), key.size()); - Tensor values(DT_INT64, TensorShape({1})); - auto status = table_->Find(ctx_, keys, &values, default_value_); - if (!status.ok()) { -// On April 2023, there is not yet an official release of Tensorflow which -// includes `message().` One will need to wait for the release following 2.12.0. 
-// The code can be updated to just be the else branch after such release exists. -#if TF_GRAPH_DEF_VERSION < 1467 - return LookupStatus(std::string(status.error_message())); -#else - return LookupStatus(std::string(status.message())); -#endif - } - - if (static_cast<int64>(values.flat<int64>()(0)) != kOutOfVocabValue) { - *value = true; - return LookupStatus::OK(); - } - *value = false; - return LookupStatus::OK(); -} - -} // namespace - -class WordpieceTokenizeWithOffsetsOp : public OpKernel { - public: - explicit WordpieceTokenizeWithOffsetsOp(OpKernelConstruction* ctx) - : OpKernel(ctx), - suffix_indicator_(GetWordSplitChar(ctx)), - max_bytes_per_word_(GetMaxCharsPerWord(ctx)), - max_chars_per_token_(GetMaxCharsPerToken(ctx)), - use_unknown_token_(GetShouldUseUnknownToken(ctx)), - unknown_token_(GetUnknownToken(ctx)), - split_unknown_characters_(GetSplitUnknownCharacters(ctx)) { - string output_row_partition_type; - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_row_partition_type", - &output_row_partition_type)); - if (output_row_partition_type == "row_lengths") { - row_partition_type_ = ROW_LENGTHS; - } else if (output_row_partition_type == "row_splits") { - row_partition_type_ = ROW_SPLITS; - } else { - OP_REQUIRES( - ctx, false, - errors::Internal("Unexpected value for output_row_partition_type")); - } - } - - void Compute(OpKernelContext* ctx) override { - const Tensor* input_values; - OP_REQUIRES_OK(ctx, ctx->input("input_values", &input_values)); - const auto& values_vec = input_values->flat<tstring>(); - - lookup::LookupInterface* lookup_table; - OP_REQUIRES_OK(ctx, - GetLookupTable("vocab_lookup_table", ctx, &lookup_table)); - core::ScopedUnref unref_me(lookup_table); - LookupTableVocab vocab_map(lookup_table, ctx); - - std::vector<string> subwords; - std::vector<int> begin_offset; - std::vector<int> end_offset; - std::vector<int> row_partition; - - if (row_partition_type_ == ROW_SPLITS) { - row_partition.push_back(0); - } - - // Iterate through all the 
values and wordpiece tokenize them. - for (int i = 0; i < values_vec.size(); ++i) { - // Tokenize into subwords and record the offset locations. - int num_wordpieces = 0; - OP_REQUIRES_OK( - ctx, ToStatus(WordpieceTokenize( - values_vec(i), max_bytes_per_word_, max_chars_per_token_, - suffix_indicator_, use_unknown_token_, unknown_token_, - split_unknown_characters_, &vocab_map, &subwords, - &begin_offset, &end_offset, &num_wordpieces))); - - // Record the row splits. - switch (row_partition_type_) { - case ROW_LENGTHS: - row_partition.push_back(num_wordpieces); - break; - case ROW_SPLITS: - row_partition.push_back(num_wordpieces + row_partition.back()); - break; - } - } - - std::vector<int64> output_subwords_shape; - output_subwords_shape.push_back(subwords.size()); - - std::vector<int64> output_row_partition_shape; - output_row_partition_shape.push_back(row_partition.size()); - - Tensor* output_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("output_values", - TensorShape(output_subwords_shape), - &output_values)); - auto output_values_vec = output_values->vec<tstring>(); - - Tensor* output_row_partition; - OP_REQUIRES_OK(ctx, - ctx->allocate_output("output_row_lengths", - TensorShape(output_row_partition_shape), - &output_row_partition)); - auto output_row_partition_vec = output_row_partition->vec<int64>(); - - Tensor* start_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values", - TensorShape(output_subwords_shape), - &start_values)); - auto start_values_vec = start_values->vec<int64>(); - - Tensor* limit_values; - OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values", - TensorShape(output_subwords_shape), - &limit_values)); - auto limit_values_vec = limit_values->vec<int64>(); - - for (int i = 0; i < subwords.size(); ++i) { - output_values_vec(i) = subwords[i]; - } - - for (int i = 0; i < row_partition.size(); ++i) { - output_row_partition_vec(i) = row_partition[i]; - } - - for (int i = 0; i < begin_offset.size(); ++i) { - start_values_vec(i) = 
begin_offset[i]; - } - - for (int i = 0; i < end_offset.size(); ++i) { - limit_values_vec(i) = end_offset[i]; - } - } - - private: - enum RowPartitionType { ROW_LENGTHS, ROW_SPLITS }; - - const string suffix_indicator_; - const int max_bytes_per_word_; - const int max_chars_per_token_; - const bool use_unknown_token_; - const string unknown_token_; - const bool split_unknown_characters_; - RowPartitionType row_partition_type_; - - TF_DISALLOW_COPY_AND_ASSIGN(WordpieceTokenizeWithOffsetsOp); -}; - -REGISTER_KERNEL_BUILDER(Name("WordpieceTokenizeWithOffsets").Device(DEVICE_CPU), - WordpieceTokenizeWithOffsetsOp); -ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("WordpieceTokenizeWithOffsets"); - -} // namespace text -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/wordpiece_kernel_test.cc b/tensorflow_text/core/kernels/wordpiece_kernel_test.cc deleted file mode 100644 index d9b81677a..000000000 --- a/tensorflow_text/core/kernels/wordpiece_kernel_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2026 TF.Text Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/shape_inference_testutil.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -TEST(WordpieceTokenizeWithOffsetsOpTest, ShapeFn) { - // WordpieceTokenizeWithOffsets(input_values, vocab_lookup_table) -> - // [output_values, output_row_lengths, start_values, limit_values] - ShapeInferenceTestOp op("WordpieceTokenizeWithOffsets"); - auto &attr = *op.node_def.mutable_attr(); - - attr["output_row_partition_type"].set_s("row_lengths"); - INFER_OK(op, "?;?", "[?];[?];[?];[?]"); - INFER_OK(op, "[?];?", "[?];[d0_0];[?];[?]"); - INFER_OK(op, "[?];[]", "[?];[d0_0];[?];[?]"); - INFER_OK(op, "[5];?", "[?];[d0_0];[?];[?]"); - INFER_OK(op, "[5];[]", "[?];[d0_0];[?];[?]"); - INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?"); - INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,2];?"); - INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]"); - - attr["output_row_partition_type"].set_s("row_splits"); - INFER_OK(op, "?;?", "[?];[?];[?];[?]"); - INFER_OK(op, "[?];?", "[?];[?];[?];[?]"); - INFER_OK(op, "[?];[]", "[?];[?];[?];[?]"); - INFER_OK(op, "[5];?", "[?];[6];[?];[?]"); - INFER_OK(op, "[5];[]", "[?];[6];[?];[?]"); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow_text/core/kernels/wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/wordpiece_tokenizer.cc deleted file mode 100644 index fd9adad5a..000000000 --- a/tensorflow_text/core/kernels/wordpiece_tokenizer.cc +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2026 TF.Text Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "icu4c/source/common/unicode/utf8.h" - -namespace tensorflow { -namespace text { - -namespace { - -LookupStatus Lookup(int byte_start, int byte_end, - const absl::string_view& token, - const std::string& suffix_indicator, - const WordpieceVocab* vocab_map, bool* in_vocab) { - int byte_len = byte_end - byte_start; - absl::string_view substr(token.data() + byte_start, byte_len); - return vocab_map->Contains( - byte_start > 0 ? absl::StrCat(suffix_indicator, substr) : substr, - in_vocab); -} - -// Sets byte_end to the longest byte sequence which: -// 1) is a proper UTF8 sequence -// 2) is in the vocab OR if split_unknown_characters is true, is a single -// UTF8 character. -// If no match is found, found_match is set to false. 
-LookupStatus LongestMatchStartingAt( - int byte_start, const absl::string_view& token, - const std::string& suffix_indicator, const int max_chars_per_subtoken, - bool split_unknown_characters, const WordpieceVocab* vocab_map, - int* byte_end, bool* found_match, bool* match_is_unknown_character) { - *match_is_unknown_character = false; - *found_match = false; - const char* token_bytes = token.data(); - std::vector<int32_t> byte_ends; - int upper_limit = token.length(); - - for (int32_t i = byte_start; i < token.length();) { - UChar32 c; - U8_NEXT(token_bytes, i, upper_limit, c); - byte_ends.push_back(i); - if (max_chars_per_subtoken > 0 && - byte_ends.size() == max_chars_per_subtoken) { - // If the max bytes of a subtoken is known, do not search beyond that - // length. - break; - } - } - int n = byte_ends.size(); - for (int i = n - 1; i >= 0; i--) { - bool in_vocab; - auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator, - vocab_map, &in_vocab); - if (!status.success) return status; - if (in_vocab) { - *byte_end = byte_ends[i]; - *found_match = true; - return LookupStatus::OK(); - } - if (i == 0 && split_unknown_characters) { - *byte_end = byte_ends[0]; - *found_match = true; - *match_is_unknown_character = true; - return LookupStatus::OK(); - } - } - return LookupStatus::OK(); -} - -// Sets the outputs 'begin_offset', 'end_offset' and 'num_word_pieces' when no -// token is found. 
-LookupStatus NoTokenFound(const absl::string_view& token, - bool use_unknown_token, - const std::string& unknown_token, - std::vector<std::string>* subwords, - std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces) { - begin_offset->push_back(0); - if (use_unknown_token) { - subwords->push_back(unknown_token); - end_offset->push_back(token.length()); - } else { - subwords->emplace_back(token.data(), token.length()); - end_offset->push_back(token.length()); - } - ++(*num_word_pieces); - - return LookupStatus::OK(); -} - -// When a subword is found, this helper function will add the outputs to -// 'subwords', 'begin_offset' and 'end_offset'. -void AddWord(const absl::string_view& token, int byte_start, int byte_end, - const std::string& suffix_indicator, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset) { - begin_offset->push_back(byte_start); - int len = byte_end - byte_start; - - if (byte_start > 0) { - // Prepend suffix_indicator if the token is within a word. - subwords->push_back(::absl::StrCat( - suffix_indicator, absl::string_view(token.data() + byte_start, len))); - } else { - subwords->emplace_back(token.data(), len); - } - end_offset->push_back(byte_end); -} - -// Adds a single unknown character subword, found when split_unknown_characters -// is true. -void AddUnknownCharacter(const absl::string_view& token, int byte_start, - int byte_end, const std::string& suffix_indicator, - bool use_unknown_token, - const std::string& unknown_token, - std::vector<std::string>* subwords, - std::vector<int>* begin_offset, - std::vector<int>* end_offset) { - begin_offset->push_back(byte_start); - end_offset->push_back(byte_end); - int len = byte_end - byte_start; - if (use_unknown_token) { - if (byte_start > 0) { - // Prepend suffix_indicator if the character is within a word. 
- subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token)); - } else { - subwords->push_back(unknown_token); - } - } else { - if (byte_start > 0) { - // Prepend suffix_indicator if the character is within a word. - subwords->push_back(::absl::StrCat( - suffix_indicator, absl::string_view(token.data() + byte_start, len))); - } else { - subwords->emplace_back(token.data(), len); - } - } -} - -LookupStatus TokenizeL2RGreedy( - const absl::string_view& token, const int max_bytes_per_token, - const int max_chars_per_subtoken, const std::string& suffix_indicator, - bool use_unknown_token, const std::string& unknown_token, - bool split_unknown_characters, const WordpieceVocab* vocab_map, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces) { - std::vector<std::string> candidate_subwords; - std::vector<int> candidate_begin_offsets; - std::vector<int> candidate_end_offsets; - const int token_len = token.length(); - for (int byte_start = 0; byte_start < token_len;) { - int byte_end; - bool found_subword; - bool match_is_unknown_character; - auto status = LongestMatchStartingAt( - byte_start, token, suffix_indicator, max_chars_per_subtoken, - split_unknown_characters, vocab_map, &byte_end, &found_subword, - &match_is_unknown_character); - if (!status.success) return status; - if (found_subword) { - if (match_is_unknown_character) { - AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator, - use_unknown_token, unknown_token, - &candidate_subwords, &candidate_begin_offsets, - &candidate_end_offsets); - } else { - AddWord(token, byte_start, byte_end, suffix_indicator, - &candidate_subwords, &candidate_begin_offsets, - &candidate_end_offsets); - } - byte_start = byte_end; - } else { - return NoTokenFound(token, use_unknown_token, unknown_token, subwords, - begin_offset, end_offset, num_word_pieces); - } - } - - subwords->insert(subwords->end(), candidate_subwords.begin(), - 
candidate_subwords.end()); - begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(), - candidate_begin_offsets.end()); - end_offset->insert(end_offset->end(), candidate_end_offsets.begin(), - candidate_end_offsets.end()); - *num_word_pieces += candidate_subwords.size(); - return LookupStatus::OK(); -} - -} // namespace - -LookupStatus WordpieceTokenize( - const absl::string_view& token, const int max_bytes_per_token, - const int max_chars_per_subtoken, const std::string& suffix_indicator, - bool use_unknown_token, const std::string& unknown_token, - bool split_unknown_characters, const WordpieceVocab* vocab_map, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces) { - int token_len = token.size(); - if (token_len > max_bytes_per_token) { - begin_offset->push_back(0); - *num_word_pieces = 1; - if (use_unknown_token) { - end_offset->push_back(unknown_token.size()); - subwords->emplace_back(unknown_token); - } else { - subwords->emplace_back(token); - end_offset->push_back(token.size()); - } - return LookupStatus::OK(); - } - return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken, - suffix_indicator, use_unknown_token, unknown_token, - split_unknown_characters, vocab_map, subwords, - begin_offset, end_offset, num_word_pieces); -} - -LookupStatus WordpieceTokenize( - const absl::string_view& token, const int max_bytes_per_token, - const std::string& suffix_indicator, bool use_unknown_token, - const std::string& unknown_token, const WordpieceVocab* vocab_map, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces) { - return WordpieceTokenize(token, max_bytes_per_token, - /* max_chars_per_subtoken= */ 0, suffix_indicator, - use_unknown_token, unknown_token, - /* split_unknown_characters= */ false, vocab_map, - subwords, begin_offset, end_offset, num_word_pieces); -} - -} // namespace text -} // 
namespace tensorflow diff --git a/tensorflow_text/core/kernels/wordpiece_tokenizer.h b/tensorflow_text/core/kernels/wordpiece_tokenizer.h index c173497ee..69913fce0 100644 --- a/tensorflow_text/core/kernels/wordpiece_tokenizer.h +++ b/tensorflow_text/core/kernels/wordpiece_tokenizer.h @@ -12,52 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ -#define TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ -#include <string> -#include <utility> -#include <vector> +#include "tensorflow/core/kernels/text/wordpiece_tokenizer.h" -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace text { - -struct LookupStatus { - LookupStatus() : error_msg(""), success(true) {} - LookupStatus(std::string msg) : error_msg(std::move(msg)), success(false) {} - std::string error_msg; - bool success; - - static LookupStatus OK() { return LookupStatus(); } -}; - -class WordpieceVocab { - public: - virtual ~WordpieceVocab() {} - virtual LookupStatus Contains(const absl::string_view key, - bool* value) const = 0; -}; - -LookupStatus WordpieceTokenize( - const absl::string_view& token, const int max_bytes_per_token, - const int max_chars_per_subtoken, const std::string& suffix_indicator, - bool use_unknown_token, const std::string& unknown_token, - bool split_unknown_characters, const WordpieceVocab* vocab_map, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces); - -// As above but with `max_bytes_per_subtoken` unknown, -// and split_unknown_characters=false. (For backwards compatability.) 
-LookupStatus WordpieceTokenize( - const absl::string_view& token, const int max_bytes_per_token, - const std::string& suffix_indicator, bool use_unknown_token, - const std::string& unknown_token, const WordpieceVocab* vocab_map, - std::vector<std::string>* subwords, std::vector<int>* begin_offset, - std::vector<int>* end_offset, int* num_word_pieces); - -} // namespace text -} // namespace tensorflow - -#endif // TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_WORDPIECE_TOKENIZER_H_ diff --git a/tensorflow_text/core/pybinds/BUILD b/tensorflow_text/core/pybinds/BUILD index a8f5057d2..2e8e19881 100644 --- a/tensorflow_text/core/pybinds/BUILD +++ b/tensorflow_text/core/pybinds/BUILD @@ -5,11 +5,14 @@ load("//tensorflow_text:tftext.bzl", "if_pywrap", "pybind_extension", "pywrap_bi licenses(["notice"]) -package(default_visibility = [ - "//nlp/sage/nlu/features/python:__pkg__", - "//nlp/semantic_parsing/learning/neural/portable/tools/release:__pkg__", - "//tensorflow_text:__subpackages__", -]) +package( + default_applicable_licenses = ["//tensorflow_text:license"], + default_visibility = [ + "//nlp/sage/nlu/features/python:__pkg__", + "//nlp/semantic_parsing/learning/neural/portable/tools/release:__pkg__", + "//tensorflow_text:__subpackages__", + ], +) pybind_extension( name = "tflite_registrar", @@ -17,10 +20,19 @@ pybind_extension( "tflite_registrar.cc", ], deps = [ + "@org_tensorflow//tensorflow/core/kernels/text:byte_splitter_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:ngrams_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:ragged_tensor_to_tensor_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:round_robin_trimmer_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:sentence_fragmenter_v2_tflite", + 
"@org_tensorflow//tensorflow/core/kernels/text:utf8_binarize_tflite", + "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_tflite", + "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:py_tflite_registerer", # lite:framework tensorflow dep, # lite/c:common tensorflow dep, # lite/kernels:builtin_ops tensorflow dep, - "//tensorflow_text/core/kernels:tflite_ops", ], ) @@ -30,7 +42,7 @@ pybind_extension( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ - "//tensorflow_text/core/kernels:fast_bert_normalizer_model_builder", + "@org_tensorflow//tensorflow/core/kernels/text:fast_bert_normalizer_model_builder", ], ) @@ -61,7 +73,7 @@ pybind_extension( ], features = ["-use_header_modules"], deps = [ - "//tensorflow_text/core/kernels:fast_wordpiece_tokenizer_model_builder", + "@org_tensorflow//tensorflow/core/kernels/text:fast_wordpiece_tokenizer_model_builder", ], ) @@ -88,7 +100,7 @@ pybind_extension( "//tensorflow_text:__subpackages__", ], deps = [ - "//tensorflow_text/core/kernels:phrase_tokenizer_model_builder", + "@org_tensorflow//tensorflow/core/kernels/text:phrase_tokenizer_model_builder", ], ) @@ -111,7 +123,7 @@ pybind_extension( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ - "//tensorflow_text/core/kernels/sentencepiece:model_converter", + "@org_tensorflow//tensorflow/core/kernels/text/sentencepiece:model_converter", ], ) @@ -121,7 +133,7 @@ pybind_extension( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ - "//tensorflow_text/core/kernels:whitespace_tokenizer_config_builder", + "@org_tensorflow//tensorflow/core/kernels/text:whitespace_tokenizer_config_builder", ], ) diff --git a/tensorflow_text/core/pybinds/pywrap_fast_bert_normalizer_model_builder.cc b/tensorflow_text/core/pybinds/pywrap_fast_bert_normalizer_model_builder.cc index d339ec6d6..9d0fe72bd 100644 --- a/tensorflow_text/core/pybinds/pywrap_fast_bert_normalizer_model_builder.cc +++ 
b/tensorflow_text/core/pybinds/pywrap_fast_bert_normalizer_model_builder.cc @@ -16,7 +16,7 @@ #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.h" +#include "tensorflow/core/kernels/text/fast_bert_normalizer_model_builder.h" namespace tensorflow { namespace text { diff --git a/tensorflow_text/core/pybinds/pywrap_fast_wordpiece_tokenizer_model_builder.cc b/tensorflow_text/core/pybinds/pywrap_fast_wordpiece_tokenizer_model_builder.cc index 573250db0..4e4ddda10 100644 --- a/tensorflow_text/core/pybinds/pywrap_fast_wordpiece_tokenizer_model_builder.cc +++ b/tensorflow_text/core/pybinds/pywrap_fast_wordpiece_tokenizer_model_builder.cc @@ -16,7 +16,7 @@ #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.h" +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_model_builder.h" namespace tensorflow { namespace text { diff --git a/tensorflow_text/core/pybinds/pywrap_model_converter.cc b/tensorflow_text/core/pybinds/pywrap_model_converter.cc index 73a932805..70900269d 100644 --- a/tensorflow_text/core/pybinds/pywrap_model_converter.cc +++ b/tensorflow_text/core/pybinds/pywrap_model_converter.cc @@ -16,7 +16,7 @@ #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "tensorflow_text/core/kernels/sentencepiece/model_converter.h" +#include "tensorflow/core/kernels/text/sentencepiece/model_converter.h" namespace tensorflow { namespace text { diff --git a/tensorflow_text/core/pybinds/pywrap_phrase_tokenizer_model_builder.cc b/tensorflow_text/core/pybinds/pywrap_phrase_tokenizer_model_builder.cc index 1221f92a9..225f2df4f 100644 --- a/tensorflow_text/core/pybinds/pywrap_phrase_tokenizer_model_builder.cc +++ b/tensorflow_text/core/pybinds/pywrap_phrase_tokenizer_model_builder.cc @@ -17,7 +17,7 @@ #include <stdexcept> #include 
"include/pybind11/pybind11.h" -#include "tensorflow_text/core/kernels/phrase_tokenizer_model_builder.h" +#include "tensorflow/core/kernels/text/phrase_tokenizer_model_builder.h" namespace tensorflow { namespace text { diff --git a/tensorflow_text/core/pybinds/pywrap_whitespace_tokenizer_config_builder.cc b/tensorflow_text/core/pybinds/pywrap_whitespace_tokenizer_config_builder.cc index 3266e2f77..f4b6c5f44 100644 --- a/tensorflow_text/core/pybinds/pywrap_whitespace_tokenizer_config_builder.cc +++ b/tensorflow_text/core/pybinds/pywrap_whitespace_tokenizer_config_builder.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <stdexcept> #include <iostream> +#include <stdexcept> + #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h" +#include "tensorflow/core/kernels/text/whitespace_tokenizer_config_builder.h" namespace tensorflow { namespace text { diff --git a/tensorflow_text/core/pybinds/tflite_registrar.cc b/tensorflow_text/core/pybinds/tflite_registrar.cc index 08a7fbfae..7619e2dff 100644 --- a/tensorflow_text/core/pybinds/tflite_registrar.cc +++ b/tensorflow_text/core/pybinds/tflite_registrar.cc @@ -14,16 +14,16 @@ #include "include/pybind11/pybind11.h" #include "include/pybind11/pytypes.h" -#include "tensorflow_text/core/kernels/byte_splitter_tflite.h" -#include "tensorflow_text/core/kernels/fast_bert_normalizer_tflite.h" -#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer_tflite.h" -#include "tensorflow_text/core/kernels/ngrams_tflite.h" -#include "tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.h" -#include "tensorflow_text/core/kernels/round_robin_trimmer_tflite.h" -#include "tensorflow_text/core/kernels/sentence_fragmenter_v2_tflite.h" -#include "tensorflow_text/core/kernels/sentencepiece/py_tflite_registerer.h" -#include 
"tensorflow_text/core/kernels/utf8_binarize_tflite.h" -#include "tensorflow_text/core/kernels/whitespace_tokenizer_tflite.h" +#include "tensorflow/core/kernels/text/byte_splitter_tflite.h" +#include "tensorflow/core/kernels/text/fast_bert_normalizer_tflite.h" +#include "tensorflow/core/kernels/text/fast_wordpiece_tokenizer_tflite.h" +#include "tensorflow/core/kernels/text/ngrams_tflite.h" +#include "tensorflow/core/kernels/text/ragged_tensor_to_tensor_tflite.h" +#include "tensorflow/core/kernels/text/round_robin_trimmer_tflite.h" +#include "tensorflow/core/kernels/text/sentence_fragmenter_v2_tflite.h" +#include "tensorflow/core/kernels/text/sentencepiece/py_tflite_registerer.h" +#include "tensorflow/core/kernels/text/utf8_binarize_tflite.h" +#include "tensorflow/core/kernels/text/whitespace_tokenizer_tflite.h" PYBIND11_MODULE(tflite_registrar, m) { m.doc() = R"pbdoc(