diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index e382381e122..0c94ba35b24 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -51,6 +51,12 @@ class SubstrOp : public OpKernel { const Tensor& len_tensor = context->input(2); const TensorShape& input_shape = input_tensor.shape(); const TensorShape& pos_shape = pos_tensor.shape(); + const TensorShape& len_shape = len_tensor.shape(); + OP_REQUIRES( + context, (pos_shape == len_shape), + errors::InvalidArgument("pos and len should have the same shape, got: ", + pos_shape.DebugString(), " vs. ", + len_shape.DebugString())); bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py index 9302152e82b..ad7b6050c29 100644 --- a/tensorflow/python/kernel_tests/substr_op_test.py +++ b/tensorflow/python/kernel_tests/substr_op_test.py @@ -492,6 +492,16 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase): with self.assertRaises(ValueError): string_ops.substr(b"test", 3, 1, unit="UTF8") + def testInvalidPos(self): + # Test case for GitHub issue 46900. + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + x = string_ops.substr(b"abc", len=1, pos=[1, -1]) + self.evaluate(x) + + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + x = string_ops.substr(b"abc", len=1, pos=[1, 2]) + self.evaluate(x) + if __name__ == "__main__": test.main()