From add7a1a911b430ed14f8b6a1609dd3796587d131 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 7 Mar 2019 13:49:48 -0800 Subject: [PATCH] Make tf.contrib.proto.* TF2-friendly. This included fixing a bug where shape inference caught an incorrect shape, but since eager mode doesn't run shape inference the core code caused a segfault. PiperOrigin-RevId: 237316781 --- .../kernel_tests/decode_proto_op_test_base.py | 21 ++-- .../kernel_tests/encode_proto_op_test_base.py | 112 +++++++++++------- tensorflow/core/kernels/encode_proto_op.cc | 11 +- 3 files changed, 90 insertions(+), 54 deletions(-) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py index c8524e98718..13749837e0c 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -296,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): field_names = ['sizes'] field_types = [dtypes.int32] - with self.cached_session() as sess: - ctensor, vtensor = self._decode_module.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) + with self.assertRaisesRegexp( + errors.DataLossError, 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + self.evaluate( + self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize)) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py index 5ec681ff55d..fac2453527d 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -30,7 +30,9 @@ from google.protobuf import text_format from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -50,56 +52,86 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): self._decode_module = decode_module self._encode_module = encode_module + def testBadSizesShape(self): + if context.executing_eagerly(): + expected_error = (errors.InvalidArgumentError, + r'Invalid shape for field double_value.') + else: + expected_error = (ValueError, + r'Shape must be at least rank 2 but is rank 0') + with self.assertRaisesRegexp(*expected_error): + self.evaluate( + self._encode_module.encode_proto( + sizes=1, + values=[np.double(1.0)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) + def testBadInputs(self): # Invalid field name - with self.cached_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - self._encode_module.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['non_existent_field']).eval() + with self.assertRaisesOpError('Unknown field: non_existent_field'): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['non_existent_field'])) # Incorrect types. - with self.cached_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - self._encode_module.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value']).eval() + with self.assertRaisesOpError('Incompatible type for field double_value.'): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) # Incorrect shapes of sizes. - with self.cached_session(): + for sizes_value in 1, np.array([[[0, 0]]]): with self.assertRaisesOpError( r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - self._encode_module.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) + if context.executing_eagerly(): + self.evaluate( + self._encode_module.encode_proto( + sizes=sizes_value, + values=[np.array([[0.0]])], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value'])) + else: + with self.cached_session(): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + self._encode_module.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval(feed_dict={ + sizes: sizes_value, + values: [[0.0]] + }) # Inconsistent shapes of values. - with self.cached_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (self._encode_module.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.TestValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) + with self.assertRaisesOpError('Values must match up to the last dimension'): + if context.executing_eagerly(): + self.evaluate( + self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[np.array([[0.0]]), + np.array([[0], [0]])], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value'])) + else: + with self.cached_session(): + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) def _testRoundtrip(self, in_bufs, message_type, fields): diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index 4a0c1943e54..213c63f41ae 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel { ctx, proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()), errors::InvalidArgument( - "Incompatible type for field " + field_names_[i] + - ". Saw dtype: ", - DataTypeString(v.dtype()), + "Incompatible type for field ", field_names_[i], + ". Saw dtype: ", DataTypeString(v.dtype()), " but field type is: ", field_descs_[i]->type_name())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()), + errors::InvalidArgument("Invalid shape for field ", field_names_[i], + ". Saw shape ", v.shape().DebugString(), + " but it should be at least a matrix.")); + // All value tensors must have the same shape prefix (i.e. batch size). TensorShape shape_prefix = v.shape(); shape_prefix.RemoveDim(shape_prefix.dims() - 1);