Merge pull request #46974 from yongtang:46900-tf.strings.substr

PiperOrigin-RevId: 356520125
Change-Id: Ifd11ff02aa51023007201dccfa02eb8213c08f7a
This commit is contained in:
TensorFlower Gardener 2021-02-09 09:39:39 -08:00
commit 890f7164b7
2 changed files with 15 additions and 0 deletions

View File

@ -51,6 +51,11 @@ class SubstrOp : public OpKernel {
const Tensor& len_tensor = context->input(2);
const TensorShape& input_shape = input_tensor.shape();
const TensorShape& pos_shape = pos_tensor.shape();
const TensorShape& len_shape = len_tensor.shape();
OP_REQUIRES(context, (pos_shape == len_shape),
errors::InvalidArgument(
"pos and len should have the same shape, got: ",
pos_shape.DebugString(), " vs. ", len_shape.DebugString()));
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);

View File

@ -492,6 +492,16 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
string_ops.substr(b"test", 3, 1, unit="UTF8")
def testInvalidPos(self):
# Test case for GitHub issue 46900.
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
x = string_ops.substr(b"abc", len=1, pos=[1, -1])
self.evaluate(x)
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
x = string_ops.substr(b"abc", len=1, pos=[1, 2])
self.evaluate(x)
if __name__ == "__main__":
test.main()