From fcff61f085bd0984430800c446c2e56c39241e1e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 25 Jul 2019 11:06:13 -0700
Subject: [PATCH] Create a C++ string-ngrams op.

PiperOrigin-RevId: 259982106
---
 .../base_api/api_def_StringNGrams.pbtxt       |  69 +++
 .../python_api/api_def_StringNGrams.pbtxt     |   4 +
 tensorflow/core/kernels/BUILD                 |  25 +
 tensorflow/core/kernels/string_ngrams_op.cc   | 201 +++++++
 .../core/kernels/string_ngrams_op_test.cc     | 554 ++++++++++++++++++
 tensorflow/core/ops/string_ops.cc             |  22 +
 tensorflow/python/ops/ragged/BUILD            |  12 +
 .../python/ops/ragged/ragged_string_ops.py    | 137 +++++
 .../ops/ragged/string_ngrams_op_test.py       | 250 ++++++++
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v1/tensorflow.strings.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.strings.pbtxt    |   4 +
 13 files changed, 1290 insertions(+)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt
 create mode 100644 tensorflow/core/kernels/string_ngrams_op.cc
 create mode 100644 tensorflow/core/kernels/string_ngrams_op_test.cc
 create mode 100644 tensorflow/python/ops/ragged/string_ngrams_op_test.py

diff --git a/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt
new file mode 100644
index 00000000000..d3d1a01ed37
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringNGrams.pbtxt
@@ -0,0 +1,69 @@
+op {
+  graph_op_name: "StringNGrams"
+  in_arg {
+    name: "data"
+    description: <<END
+The values tensor of the ragged string tensor to make ngrams out of. Must be a
+1D string tensor.
+END
+  }
+  in_arg {
+    name: "data_splits"
+    description: <<END
+The splits tensor of the ragged string tensor to make ngrams out of.
+END
+  }
+  out_arg {
+    name: "ngrams"
+    description: <<END
+The values tensor of the output ngrams ragged tensor.
+END
+  }
+  out_arg {
+    name: "ngrams_splits"
+    description: <<END
+The splits tensor of the output ngrams ragged tensor.
+END
+  }
+  attr {
+    name: "separator"
+    description: <<END
+The string to append between elements of the token. Use "" for no separator.
+END
+  }
+  attr {
+    name: "ngram_widths"
+    description: <<END
+The sizes of the ngrams to create.
+END
+  }
+  attr {
+    name: "left_pad"
+    description: <<END
+The string to use to pad the left side of the ngram sequence. Only used if
+pad_width != 0.
+END
+  }
+  attr {
+    name: "right_pad"
+    description: <<END
+The string to use to pad the right side of the ngram sequence. Only used if
+pad_width != 0.
+END
+}
+  attr {
+    name: "pad_width"
+    description: <<END
+The number of padding elements to add to each side of each
+sequence. Note that padding will never be greater than 'ngram_widths'-1
+regardless of this value. If `pad_width=-1`, then add `max(ngram_widths)-1`
+elements.
+END
+  }
+  summary: "Creates ngrams from ragged string data."
+  description: <<END
+This op accepts a ragged tensor with 1 ragged dimension containing only
+strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
+of that string, joined along the innermost axis.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt
new file mode 100644
index 00000000000..acefd9ba024
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringNGrams.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StringNGrams"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 572afde42d9..4b084033efe 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5317,6 +5317,7 @@ cc_library(
         ":string_join_op",
         ":string_length_op",
         ":string_lower_op",
+        ":string_ngrams_op",
         ":string_split_op",
         ":string_strip_op",
         ":string_to_hash_bucket_op",
@@ -5457,6 +5458,30 @@ tf_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "string_ngrams_op",
+    srcs = ["string_ngrams_op.cc"],
+    deps = STRING_DEPS + [
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_cc_test(
+    name = "string_ngrams_op_test",
+    srcs = ["string_ngrams_op_test.cc"],
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        ":string_ngrams_op",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "string_strip_op",
     prefix = "string_strip_op",
diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc
new file mode 100644
index 00000000000..37a7aa956d0
--- /dev/null
+++ b/tensorflow/core/kernels/string_ngrams_op.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <locale>
+#include <string>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace text {
+
+namespace {
+template <typename SPLITS_TYPE>
+class StringNGramsOp : public tensorflow::OpKernel {
+ public:
+  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
+      : tensorflow::OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
+    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
+    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
+    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
+    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
+                                             &preserve_short_));
+  }
+
+  int get_pad_width(const int ngram_width) const {
+    // Ngrams can be padded with either a fixed pad width or a dynamic pad
+    // width depending on the 'pad_width' arg, but in no case should the padding
+    // ever be wider than 'ngram_width' - 1.
+    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
+                    ngram_width - 1);
+  }
+
+  int get_num_ngrams(const int length, const int ngram_width) const {
+    int pad_width = get_pad_width(ngram_width);
+    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
+  }
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    const tensorflow::Tensor* data;
+    OP_REQUIRES_OK(context, context->input("data", &data));
+    const auto& input_data = data->flat<string>().data();
+
+    const tensorflow::Tensor* splits;
+    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
+    const auto& splits_vec = splits->flat<SPLITS_TYPE>();
+
+    // If there is no data or size, return an empty RT.
+    if (data->flat<string>().size() == 0 || splits_vec.size() == 0) {
+      tensorflow::Tensor* empty;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, data->shape(), &empty));
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, splits->shape(), &empty));
+      return;
+    }
+
+    int num_batch_items = splits_vec.size() - 1;
+    tensorflow::Tensor* ngrams_splits;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
+    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();
+
+    ngrams_splits_data[0] = 0;
+    for (int i = 1; i <= num_batch_items; ++i) {
+      int length = splits_vec(i) - splits_vec(i - 1);
+      int num_ngrams = 0;
+      for (int ngram_width : ngram_widths_)
+        num_ngrams += get_num_ngrams(length, ngram_width);
+      if (preserve_short_ && length > 0 && num_ngrams == 0) {
+        num_ngrams = 1;
+      }
+      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
+    }
+
+    tensorflow::Tensor* ngrams;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
+    auto ngrams_data = ngrams->flat<string>().data();
+
+    for (int i = 0; i < num_batch_items; ++i) {
+      auto data_start = &input_data[splits_vec(i)];
+      int output_start_idx = ngrams_splits_data[i];
+      for (int ngram_width : ngram_widths_) {
+        auto output_start = &ngrams_data[output_start_idx];
+        int length = splits_vec(i + 1) - splits_vec(i);
+        int num_ngrams = get_num_ngrams(length, ngram_width);
+        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
+        output_start_idx += num_ngrams;
+      }
+      // If we're preserving short sequences, check to see if no sequence was
+      // generated by comparing the current output start idx to the original
+      // one (ngram_splits_data). If no ngrams were generated, then they will
+      // be equal (since we increment output_start_idx by num_ngrams every
+      // time we create a set of ngrams.)
+      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
+        int data_length = splits_vec(i + 1) - splits_vec(i);
+        // One legitimate reason to not have any ngrams when preserve_short_
+        // is true is if the sequence itself is empty. In that case, move on.
+        if (data_length == 0) {
+          continue;
+        }
+        // We don't have to worry about dynamic padding sizes here: if padding
+        // was dynamic, every sequence would have had sufficient padding to
+        // generate at least one ngram.
+        int ngram_width = data_length + 2 * pad_width_;
+        auto output_start = &ngrams_data[output_start_idx];
+        int num_ngrams = 1;
+        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
+      }
+    }
+  }
+
+  void CreateNgrams(const string* data, string* output, int num_ngrams,
+                    int ngram_width) const {
+    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
+      int pad_width = get_pad_width(ngram_width);
+      int left_padding = std::max(0, pad_width - ngram_index);
+      int right_padding =
+          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
+      int num_tokens = ngram_width - (left_padding + right_padding);
+      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;
+
+      // Calculate the total expected size of the ngram so we can reserve the
+      // correct amount of space in the string.
+      int ngram_size = 0;
+      // Size of the left padding.
+      ngram_size += left_padding * left_pad_.length();
+      // Size of the tokens.
+      for (int n = 0; n < num_tokens; ++n) {
+        ngram_size += data[data_start_index + n].length();
+      }
+      // Size of the right padding.
+      ngram_size += right_padding * right_pad_.length();
+      // Size of the separators.
+      int num_separators = left_padding + right_padding + num_tokens - 1;
+      ngram_size += num_separators * separator_.length();
+
+      // Build the ngram.
+      string* ngram = &output[ngram_index];
+      ngram->reserve(ngram_size);
+      for (int n = 0; n < left_padding; ++n) {
+        *ngram += left_pad_;
+        *ngram += separator_;
+      }
+      for (int n = 0; n < num_tokens - 1; ++n) {
+        *ngram += data[data_start_index + n];
+        *ngram += separator_;
+      }
+      *ngram += data[data_start_index + num_tokens - 1];
+      for (int n = 0; n < right_padding; ++n) {
+        *ngram += separator_;
+        *ngram += right_pad_;
+      }
+
+      // In debug mode only: validate that we've reserved enough space for the
+      // ngram.
+      DCHECK_EQ(ngram_size, ngram->size());
+    }
+  }
+
+  string separator_;
+  string left_pad_;
+  string right_pad_;
+  bool use_pad_;
+  bool extend_pad_;
+  bool preserve_short_;
+
+  std::vector<int> ngram_widths_;
+  int pad_width_;
+};
+
+}  // namespace
+REGISTER_KERNEL_BUILDER(Name("StringNGrams")
+                            .Device(tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsplits"),
+                        StringNGramsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("StringNGrams")
+                            .Device(tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsplits"),
+                        StringNGramsOp<int64>);
+
+}  // namespace text
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_ngrams_op_test.cc b/tensorflow/core/kernels/string_ngrams_op_test.cc
new file mode 100644
index 00000000000..afd1700c9ab
--- /dev/null
+++ b/tensorflow/core/kernels/string_ngrams_op_test.cc
@@ -0,0 +1,554 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace text {
+
+using tensorflow::FakeInput;
+using tensorflow::NodeDefBuilder;
+using tensorflow::Status;
+using tensorflow::TensorShape;
+
+class NgramKernelTest : public tensorflow::OpsTestBase {
+ public:
+  void MakeOp(string separator, std::vector<int> ngram_width, string left_pad,
+              string right_pad, int pad_width, bool preserve) {
+    TF_ASSERT_OK(NodeDefBuilder("tested_op", "StringNGrams")
+                     .Attr("separator", separator)
+                     .Attr("ngram_widths", ngram_width)
+                     .Attr("left_pad", left_pad)
+                     .Attr("right_pad", right_pad)
+                     .Attr("pad_width", pad_width)
+                     .Attr("preserve_short_sequences", preserve)
+                     .Input(FakeInput())
+                     .Input(FakeInput())
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  void assert_string_equal(const std::vector<string> &expected,
+                           const Tensor &value) {
+    Tensor expected_tensor(allocator(), DT_STRING,
+                           TensorShape({static_cast<int64>(expected.size())}));
+    test::FillValues<string>(&expected_tensor, expected);
+    test::ExpectTensorEqual<string>(expected_tensor, value);
+  }
+  void assert_int64_equal(const std::vector<int64> &expected,
+                          const Tensor &value) {
+    Tensor expected_tensor(allocator(), DT_INT64,
+                           TensorShape({static_cast<int64>(expected.size())}));
+    test::FillValues<int64>(&expected_tensor, expected);
+    test::ExpectTensorEqual<int64>(expected_tensor, value);
+  }
+};
+
+TEST_F(NgramKernelTest, TestPaddedTrigrams) {
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                              //
+      {"LP|LP|a", "LP|a|b", "a|b|c", "b|c|d", "c|d|RP", "d|RP|RP",  // 0
+       "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});                  // 1
+  std::vector<int64> expected_splits({0, 6, 10});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedBigramsAndTrigrams) {
+  MakeOp("|", {2, 3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|LP|a", "LP|a|b", "a|b|c",
+       "b|c|d", "c|d|RP", "d|RP|RP",                                       // 0
+       "LP|e", "e|f", "f|RP", "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});  // 1
+  std::vector<int64> expected_splits({0, 11, 18});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedBigrams) {
+  MakeOp("|", {2}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(       //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "LP|e", "e|f", "f|RP"});              // 1
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddingIsAtMostNGramSizeMinus1) {
+  MakeOp("|", {2}, "LP", "RP", 4, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(       //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "LP|e", "e|f", "f|RP"});              // 1
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestPaddedUnigramAndBigrams) {
+  MakeOp("|", {1, 2}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                           //
+      {"a", "b", "c", "d", "LP|a", "a|b", "b|c", "c|d", "d|RP",  // 0
+       "e", "f", "LP|e", "e|f", "f|RP"});                        // 1
+  std::vector<int64> expected_splits({0, 9, 14});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingPaddedNGrams) {
+  // This test validates that n-grams with both left and right padding in a
+  // single ngram token are created correctly.
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                     //
+      {"LP|LP|a", "LP|a|RP", "a|RP|RP",                    // ngrams for elem. 0
+       "LP|LP|b", "LP|b|c", "b|c|d", "c|d|RP", "d|RP|RP",  // ngrams for elem. 1
+       "LP|LP|e", "LP|e|f", "e|f|RP", "f|RP|RP"});         // ngrams for elem. 2
+  std::vector<int64> expected_splits({0, 3, 8, 12});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingPaddedMultiCharNGrams) {
+  MakeOp("|", {3}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}),
+                            {"aa", "bb", "cc", "dd", "ee", "ff"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                              //
+      {"LP|LP|aa", "LP|aa|RP", "aa|RP|RP",                          //
+       "LP|LP|bb", "LP|bb|cc", "bb|cc|dd", "cc|dd|RP", "dd|RP|RP",  //
+       "LP|LP|ee", "LP|ee|ff", "ee|ff|RP", "ff|RP|RP"});            //
+  std::vector<int64> expected_splits({0, 3, 8, 12});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestMultiOverlappingPaddedNGrams) {
+  // This test validates that n-grams with more than 1 padding value on each
+  // side are created correctly.
+  MakeOp("|", {5}, "LP", "RP", -1, false);
+  // Batch items are:
+  // 0: "a"
+  AddInputFromArray<string>(TensorShape({1}), {"a"});
+  AddInputFromArray<int64>(TensorShape({2}), {0, 1});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|LP|LP|LP|a", "LP|LP|LP|a|RP",
+                                       "LP|LP|a|RP|RP", "LP|a|RP|RP|RP",
+                                       "a|RP|RP|RP|RP"});
+  std::vector<int64> expected_splits({0, 5});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigrams) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d"});
+  std::vector<int64> expected_splits({0, 2, 2});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithEmptySequence) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d"});
+  std::vector<int64> expected_splits({0, 2, 2, 2});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShort) {
+  MakeOp("|", {3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsWithPreserveShortAndEmptySequence) {
+  MakeOp("|", {3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 4, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 2, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndQuadgramsWithPreserveShort) {
+  MakeOp("|", {4, 3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b|c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 3, 4});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigrams) {
+  MakeOp("|", {2, 3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigramsAndTrigramsWithPreserveShort) {
+  MakeOp("|", {2, 3}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Note that in this case, because the bigram 'e|f' was already generated,
+  // the op will not generate a special preserve_short bigram.
+  std::vector<string> expected_values(
+      {"a|b", "b|c", "c|d", "a|b|c", "b|c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedTrigramsAndBigramsWithPreserveShort) {
+  MakeOp("|", {3, 2}, "", "", 0, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Note that in this case, because the bigram 'e|f' was already generated,
+  // the op will not generate a special preserve_short bigram.
+  std::vector<string> expected_values(
+      {"a|b|c", "b|c|d", "a|b", "b|c", "c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 5, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestUnpaddedBigrams) {
+  MakeOp("|", {2}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a|b", "b|c", "c|d", "e|f"});
+  std::vector<int64> expected_splits({0, 3, 4});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGrams) {
+  MakeOp("|", {3}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"b|c|d"});
+  std::vector<int64> expected_splits({0, 0, 1, 1});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingUnpaddedNGramsNoOutput) {
+  MakeOp("|", {5}, "", "", 0, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({});
+  std::vector<int64> expected_splits({0, 0, 0, 0});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedTrigrams) {
+  MakeOp("|", {3}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|a|b", "a|b|c", "b|c|d", "c|d|RP",  //
+                                       "LP|e|f", "e|f|RP"});
+  std::vector<int64> expected_splits({0, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedBigrams) {
+  MakeOp("|", {2}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|a", "a|b", "b|c", "c|d", "d|RP",  //
+                                       "LP|e", "e|f", "f|RP"});
+  std::vector<int64> expected_splits({0, 5, 8});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedBigramsAnd5grams) {
+  MakeOp("|", {2, 5}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(                                   //
+      {"LP|a", "a|b", "b|c", "c|d", "d|RP", "LP|a|b|c|d", "a|b|c|d|RP",  //
+       "LP|e", "e|f", "f|RP"});
+  std::vector<int64> expected_splits({0, 7, 10});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPadded5gramsWithPreserveShort) {
+  MakeOp("|", {5}, "LP", "RP", 1, true);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(  //
+      {"LP|a|b|c|d", "a|b|c|d|RP",      //
+       "LP|e|f|RP"});
+  std::vector<int64> expected_splits({0, 2, 3});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGrams) {
+  MakeOp("|", {3}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values(
+      {"LP|a|RP",                    // ngrams for elem. 0
+       "LP|b|c", "b|c|d", "c|d|RP",  // ngrams for elem. 1
+       "LP|e|f", "e|f|RP"});         // ngrams for elem. 2
+  std::vector<int64> expected_splits({0, 1, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestOverlappingSinglyPaddedNGramsNoOutput) {
+  MakeOp("|", {5}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a"
+  // 1: "b", "c", "d"
+  // 2: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({4}), {0, 1, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"LP|b|c|d|RP"});
+  std::vector<int64> expected_splits({0, 0, 1, 1});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestSinglyPaddedUnigrams) {
+  MakeOp("|", {1}, "LP", "RP", 1, false);
+  // Batch items are:
+  // 0: "a", "b", "c", "d"
+  // 1: "e", "f"
+  AddInputFromArray<string>(TensorShape({6}), {"a", "b", "c", "d", "e", "f"});
+  AddInputFromArray<int64>(TensorShape({3}), {0, 4, 6});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({"a", "b", "c", "d", "e", "f"});
+  std::vector<int64> expected_splits({0, 4, 6});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, TestEmptyInput) {
+  MakeOp("|", {1}, "LP", "RP", 3, false);
+  AddInputFromArray<string>(TensorShape({0}), {});
+  AddInputFromArray<int64>(TensorShape({0}), {});
+  TF_ASSERT_OK(RunOpKernel());
+
+  std::vector<string> expected_values({});
+  std::vector<int64> expected_splits({});
+
+  assert_string_equal(expected_values, *GetOutput(0));
+  assert_int64_equal(expected_splits, *GetOutput(1));
+}
+
+TEST_F(NgramKernelTest, ShapeFn) {
+  ShapeInferenceTestOp op("StringNGrams");
+  INFER_OK(op, "?;?", "[?];[?]");
+  INFER_OK(op, "[1];?", "[?];[?]");
+  INFER_OK(op, "[1];[2]", "[?];in1");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
+}
+
+}  // namespace text
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 2e07db36531..4d9ad0a56c5 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -365,4 +365,26 @@ REGISTER_OP("UnicodeDecodeWithOffsets")
       return Status::OK();
     });
 
+REGISTER_OP("StringNGrams")
+    .Attr("separator: string")
+    .Attr("ngram_widths: list(int) >= 0")
+    .Attr("left_pad: string")
+    .Attr("right_pad: string")
+    .Attr("pad_width: int")
+    .Attr("preserve_short_sequences: bool")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
+    .Input("data: string")
+    .Input("data_splits: Tsplits")
+    .Output("ngrams: string")
+    .Output("ngrams_splits: Tsplits")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->UnknownShapeOfRank(1));
+      ShapeHandle data = c->input(0);
+      TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
+      ShapeHandle data_splits = c->input(1);
+      TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
+      c->set_output(1, data_splits);
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index f1a802b8c7d..1aade2caf3e 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -1070,3 +1070,15 @@ py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+py_test(
+    name = "string_ngrams_op_test",
+    size = "small",
+    srcs = ["string_ngrams_op_test.py"],
+    python_version = "PY2",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_string_ops",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py
index ed52e9a88fa..e4341b8ce4d 100644
--- a/tensorflow/python/ops/ragged/ragged_string_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_string_ops.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_array_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import compat as util_compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
@@ -650,3 +651,139 @@ def reduce_join(inputs, axis=None, keepdims=None, separator="", name=None):
   return ragged_math_ops.ragged_reduce_aggregate(
       string_ops.reduce_join, string_ops.unsorted_segment_join, inputs, axis,
       keepdims, separator, name or "RaggedSegmentJoin")
+
+
+@tf_export("strings.ngrams")
+def ngrams(data,
+           ngram_width,
+           separator=" ",
+           pad_values=None,
+           padding_width=None,
+           preserve_short_sequences=False,
+           name=None):
+  """Create a tensor of n-grams based on `data`.
+
+  Creates a tensor of n-grams based on `data`. The n-grams are created by
+  joining windows of `width` adjacent strings from the inner axis of `data`
+  using `separator`.
+
+  The input data can be padded on both the start and end of the sequence, if
+  desired, using the `pad_values` argument. If set, `pad_values` should contain
+  either a tuple of strings or a single string; the 0th element of the tuple
+  will be used to pad the left side of the sequence and the 1st element of the
+  tuple will be used to pad the right side of the sequence. The `padding_width`
+  arg controls how many padding values are added to each side; it defaults to
+  `ngram_width-1`.
+
+  If this op is configured to not have padding, or if it is configured to add
+  padding with `padding_width` set to less than ngram_width-1, it is possible
+  that a sequence, or a sequence plus padding, is smaller than the ngram
+  width. In that case, no ngrams will be generated for that sequence. This can
+  be prevented by setting `preserve_short_sequences`, which will cause the op
+  to always generate at least one ngram per non-empty sequence.
+
+  Args:
+    data: A Tensor or RaggedTensor containing the source data for the ngrams.
+    ngram_width: The width(s) of the ngrams to create. If this is a list or
+      tuple, the op will return ngrams of all specified arities in list order.
+      Values must be non-Tensor integers greater than 0.
+    separator: The separator string used between ngram elements. Must be a
+      string constant, not a Tensor.
+    pad_values: A tuple of (left_pad_value, right_pad_value), a single string,
+      or None. If None, no padding will be added; if a single string, then that
+      string will be used for both left and right padding. Values must be Python
+      strings.
+    padding_width: If set, `padding_width` pad values will be added to both
+      sides of each sequence. Defaults to `ngram_width`-1. Must be greater than
+      0. (Note that 1-grams are never padded, regardless of this value.)
+    preserve_short_sequences: If true, then ensure that at least one ngram is
+      generated for each input sequence.  In particular, if an input sequence is
+      shorter than `min(ngram_width) + 2*pad_width`, then generate a single
+      ngram containing the entire sequence.  If false, then no ngrams are
+      generated for these short input sequences.
+    name: The op name.
+
+  Returns:
+    A RaggedTensor of ngrams. If `data.shape=[D1...DN, S]`, then
+    `output.shape=[D1...DN, NUM_NGRAMS]`, where
+    `NUM_NGRAMS=S-ngram_width+1+2*padding_width`.
+
+  Raises:
+    TypeError: if `pad_values` is set to an invalid type.
+    ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an
+      invalid value.
+  """
+
+  with ops.name_scope(name, "StringNGrams", [data]):
+    if pad_values is None:
+      left_pad = ""
+      right_pad = ""
+    elif isinstance(pad_values, (list, tuple)):
+      if (not isinstance(pad_values[0], util_compat.bytes_or_text_types) or
+          not isinstance(pad_values[1], util_compat.bytes_or_text_types)):
+        raise TypeError(
+            "pad_values must be a string, tuple of strings, or None.")
+      left_pad = pad_values[0]
+      right_pad = pad_values[1]
+    else:
+      if not isinstance(pad_values, util_compat.bytes_or_text_types):
+        raise TypeError(
+            "pad_values must be a string, tuple of strings, or None.")
+      left_pad = pad_values
+      right_pad = pad_values
+
+    if padding_width is not None and padding_width < 1:
+      raise ValueError("padding_width must be greater than 0.")
+
+    if padding_width is not None and pad_values is None:
+      raise ValueError("pad_values must be provided if padding_width is set.")
+
+    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+        data, name="data", dtype=dtypes.string)
+
+    if not isinstance(data, ragged_tensor.RaggedTensor):
+      if data.shape.ndims is None:
+        raise ValueError("Rank of data must be known.")
+      elif data.shape.ndims == 0:
+        raise ValueError("Data must have rank>0")
+      elif data.shape.ndims == 1:
+        rt = ragged_tensor.RaggedTensor.from_row_starts(
+            data, [0], validate=False)
+        return ngrams(rt, ngram_width, separator, pad_values, padding_width,
+                      preserve_short_sequences, name)[0]
+      else:
+        data = ragged_tensor.RaggedTensor.from_tensor(
+            data, ragged_rank=data.shape.ndims - 1)
+
+    if data.ragged_rank > 1:
+      return data.with_values(
+          ngrams(data.values, ngram_width, separator, pad_values, padding_width,
+                 preserve_short_sequences, name))
+
+    if pad_values is None:
+      padding_width = 0
+
+    if pad_values is not None and padding_width is None:
+      padding_width = -1
+
+    if not isinstance(ngram_width, (list, tuple)):
+      ngram_widths = [ngram_width]
+    else:
+      ngram_widths = ngram_width
+    for width in ngram_widths:
+      if width < 1:
+        raise ValueError("All ngram_widths must be greater than 0. Got %s" %
+                         ngram_width)
+
+    output, output_splits = gen_string_ops.string_n_grams(
+        data=data.flat_values,
+        data_splits=data.row_splits,
+        separator=separator,
+        ngram_widths=ngram_widths,
+        left_pad=left_pad,
+        right_pad=right_pad,
+        pad_width=padding_width,
+        preserve_short_sequences=preserve_short_sequences)
+
+    return ragged_tensor.RaggedTensor.from_row_splits(
+        values=output, row_splits=output_splits, validate=False)
diff --git a/tensorflow/python/ops/ragged/string_ngrams_op_test.py b/tensorflow/python/ops/ragged/string_ngrams_op_test.py
new file mode 100644
index 00000000000..a10829c50fc
--- /dev/null
+++ b/tensorflow/python/ops/ragged/string_ngrams_op_test.py
@@ -0,0 +1,250 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the b"License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an b"AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Tensorflow strings.ngrams op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_string_ops
+from tensorflow.python.platform import test
+
+
+class StringNgramsTest(test_util.TensorFlowTestCase):
+
+  def test_unpadded_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_tuple_multi_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(2, 3), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb", b"bb|cc", b"cc|dd", b"aa|bb|cc", b"bb|cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_tuple_multi_ngrams_inverted_order(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(3, 2), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb", b"bb|cc", b"cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_list_multi_ngrams(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=[2, 3], separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb", b"bb|cc", b"cc|dd", b"aa|bb|cc", b"bb|cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_multi_ngram_ordering(self):
+    data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=[3, 2], separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb", b"bb|cc", b"cc|dd"],
+                       [b"ee|ff"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_fully_padded_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|RP", b"a|RP|RP"],  # 0
+        [b"LP|LP|b", b"LP|b|c", b"b|c|d", b"c|d|RP", b"d|RP|RP"],  # 1
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"]  # 2
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ngram_padding_size_cap(self):
+    # Validate that the padding size is never greater than ngram_size - 1.
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=3,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=10)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|RP", b"a|RP|RP"],  # 0
+        [b"LP|LP|b", b"LP|b|c", b"b|c|d", b"c|d|RP", b"d|RP|RP"],  # 1
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"]  # 2
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[], [b"LP|b|c|d|RP"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_ngrams_with_preserve_short(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1,
+        preserve_short_sequences=True)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"LP|a|RP"], [b"LP|b|c|d|RP"], [b"LP|e|f|RP"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_singly_padded_multiple_ngrams(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=(1, 5),
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"a"], [b"b", b"c", b"d", b"LP|b|c|d|RP"], [b"e", b"f"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_single_padding_string(self):
+    data = [[b"a"], [b"b", b"c", b"d"], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=b"[PAD]",
+        padding_width=1)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[], [b"[PAD]|b|c|d|[PAD]"], []]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_explicit_multiply_padded_ngrams(self):
+    data = [[b"a"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=5,
+        separator=b"|",
+        pad_values=(b"LP", b"RP"),
+        padding_width=2)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[b"LP|LP|a|RP|RP"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd"]], [[]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_and_preserve(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor,
+        ngram_width=3,
+        separator=b"|",
+        preserve_short_sequences=True)
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd"]], [[b"ee|ff"]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_bigrams(self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=2, separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb", b"bb|cc", b"cc|dd"]], [[b"ee|ff"]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_ragged_inputs_with_multiple_ragged_dimensions_and_multiple_ngrams(
+      self):
+    data = [[[[b"aa", b"bb", b"cc", b"dd"]], [[b"ee", b"ff"]]]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(3, 4), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[[[b"aa|bb|cc", b"bb|cc|dd", b"aa|bb|cc|dd"]], [[]]]]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_dense_input(self):
+    data = [[b"a", b"z"], [b"b", b""], [b"e", b"f"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [
+        [b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"],
+        [b"LP|LP|b", b"LP|b|", b"b||RP", b"|RP|RP"],
+        [b"LP|LP|e", b"LP|e|f", b"e|f|RP", b"f|RP|RP"],
+    ]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_vector_input(self):
+    data = [b"a", b"z"]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=3, separator=b"|", pad_values=(b"LP", b"RP"))
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [b"LP|LP|a", b"LP|a|z", b"a|z|RP", b"z|RP|RP"]
+    self.assertAllEqual(expected_ngrams, result)
+
+  def test_dense_input_with_multiple_ngrams(self):
+    data = [[b"a", b"b", b"c", b"d"], [b"e", b"f", b"g", b"h"]]
+    data_tensor = ragged_factory_ops.constant(data)
+    ngram_op = ragged_string_ops.ngrams(
+        data_tensor, ngram_width=(1, 2, 3), separator=b"|")
+    result = self.evaluate(ngram_op)
+    expected_ngrams = [[
+        b"a", b"b", b"c", b"d", b"a|b", b"b|c", b"c|d", b"a|b|c", b"b|c|d"
+    ], [b"e", b"f", b"g", b"h", b"e|f", b"f|g", b"g|h", b"e|f|g", b"f|g|h"]]
+    self.assertAllEqual(expected_ngrams, result)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 473323b088c..ccbf09cb4e2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -4108,6 +4108,10 @@ tf_module {
     name: "StringLower"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
+  member_method {
+    name: "StringNGrams"
+    argspec: "args=[\'data\', \'data_splits\', \'separator\', \'ngram_widths\', \'left_pad\', \'right_pad\', \'pad_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StringSplit"
     argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 1a73ab6a7e5..b5008339866 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -24,6 +24,10 @@ tf_module {
     name: "lower"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
+  member_method {
+    name: "ngrams"
+    argspec: "args=[\'data\', \'ngram_width\', \'separator\', \'pad_values\', \'padding_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\' \', \'None\', \'None\', \'False\', \'None\'], "
+  }
   member_method {
     name: "reduce_join"
     argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 473323b088c..ccbf09cb4e2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -4108,6 +4108,10 @@ tf_module {
     name: "StringLower"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
+  member_method {
+    name: "StringNGrams"
+    argspec: "args=[\'data\', \'data_splits\', \'separator\', \'ngram_widths\', \'left_pad\', \'right_pad\', \'pad_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StringSplit"
     argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 6f0cd870f6f..8fc27ccedab 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -24,6 +24,10 @@ tf_module {
     name: "lower"
     argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
+  member_method {
+    name: "ngrams"
+    argspec: "args=[\'data\', \'ngram_width\', \'separator\', \'pad_values\', \'padding_width\', \'preserve_short_sequences\', \'name\'], varargs=None, keywords=None, defaults=[\' \', \'None\', \'None\', \'False\', \'None\'], "
+  }
   member_method {
     name: "reduce_join"
     argspec: "args=[\'inputs\', \'axis\', \'keepdims\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\'], "