Fix crash of tf.strings.substr when pos and len have different shapes
This PR tries to address the issue raised in 46900 where tf.strings.substr will crash when pos and len have different shapes. According to the documentation of tf.strings.substr, ValueError should be raised instead when pos and len does not have the same shape. This PR add shape check in kernel to allows grace error throw (instead of crash). This PR fixes 46900. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
844002ed99
commit
7edb8c9b83
@ -51,6 +51,12 @@ class SubstrOp : public OpKernel {
|
||||
const Tensor& len_tensor = context->input(2);
|
||||
const TensorShape& input_shape = input_tensor.shape();
|
||||
const TensorShape& pos_shape = pos_tensor.shape();
|
||||
const TensorShape& len_shape = len_tensor.shape();
|
||||
OP_REQUIRES(
|
||||
context, (pos_shape == len_shape),
|
||||
errors::InvalidArgument("pos and len should have the same shape, got: ",
|
||||
pos_shape.DebugString(), " vs. ",
|
||||
len_shape.DebugString()));
|
||||
|
||||
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
|
||||
|
||||
|
@ -492,6 +492,16 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
string_ops.substr(b"test", 3, 1, unit="UTF8")
|
||||
|
||||
def testInvalidPos(self):
|
||||
# Test case for GitHub issue 46900.
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
x = string_ops.substr(b"abc", len=1, pos=[1, -1])
|
||||
self.evaluate(x)
|
||||
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
x = string_ops.substr(b"abc", len=1, pos=[1, 2])
|
||||
self.evaluate(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user