Make tf.contrib.proto.* TF2-friendly.

This included fixing a bug where shape inference caught an incorrect shape,
but since eager mode doesn't run shape inference the core code caused a
segfault.

PiperOrigin-RevId: 237316781
This commit is contained in:
Eugene Brevdo 2019-03-07 13:49:48 -08:00 committed by TensorFlower Gardener
parent 17baa62ff4
commit add7a1a911
3 changed files with 90 additions and 54 deletions

View File

@ -296,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names = ['sizes'] field_names = ['sizes']
field_types = [dtypes.int32] field_types = [dtypes.int32]
with self.cached_session() as sess: with self.assertRaisesRegexp(
ctensor, vtensor = self._decode_module.decode_proto( errors.DataLossError, 'Unable to parse binary protobuf'
batch, '|Failed to consume entire buffer'):
message_type=msg_type, self.evaluate(
field_names=field_names, self._decode_module.decode_proto(
output_types=field_types, batch,
sanitize=sanitize) message_type=msg_type,
with self.assertRaisesRegexp(errors.DataLossError, field_names=field_names,
'Unable to parse binary protobuf' output_types=field_types,
'|Failed to consume entire buffer'): sanitize=sanitize))
_ = sess.run([ctensor] + vtensor)

View File

@ -30,7 +30,9 @@ from google.protobuf import text_format
from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -50,56 +52,86 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
self._decode_module = decode_module self._decode_module = decode_module
self._encode_module = encode_module self._encode_module = encode_module
def testBadSizesShape(self):
if context.executing_eagerly():
expected_error = (errors.InvalidArgumentError,
r'Invalid shape for field double_value.')
else:
expected_error = (ValueError,
r'Shape must be at least rank 2 but is rank 0')
with self.assertRaisesRegexp(*expected_error):
self.evaluate(
self._encode_module.encode_proto(
sizes=1,
values=[np.double(1.0)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']))
def testBadInputs(self): def testBadInputs(self):
# Invalid field name # Invalid field name
with self.cached_session(): with self.assertRaisesOpError('Unknown field: non_existent_field'):
with self.assertRaisesOpError('Unknown field: non_existent_field'): self.evaluate(
self._encode_module.encode_proto( self._encode_module.encode_proto(
sizes=[[1]], sizes=[[1]],
values=[np.array([[0.0]], dtype=np.int32)], values=[np.array([[0.0]], dtype=np.int32)],
message_type='tensorflow.contrib.proto.TestValue', message_type='tensorflow.contrib.proto.TestValue',
field_names=['non_existent_field']).eval() field_names=['non_existent_field']))
# Incorrect types. # Incorrect types.
with self.cached_session(): with self.assertRaisesOpError('Incompatible type for field double_value.'):
with self.assertRaisesOpError( self.evaluate(
'Incompatible type for field double_value.'): self._encode_module.encode_proto(
self._encode_module.encode_proto( sizes=[[1]],
sizes=[[1]], values=[np.array([[0.0]], dtype=np.int32)],
values=[np.array([[0.0]], dtype=np.int32)], message_type='tensorflow.contrib.proto.TestValue',
message_type='tensorflow.contrib.proto.TestValue', field_names=['double_value']))
field_names=['double_value']).eval()
# Incorrect shapes of sizes. # Incorrect shapes of sizes.
with self.cached_session(): for sizes_value in 1, np.array([[[0, 0]]]):
with self.assertRaisesOpError( with self.assertRaisesOpError(
r'sizes should be batch_size \+ \[len\(field_names\)\]'): r'sizes should be batch_size \+ \[len\(field_names\)\]'):
sizes = array_ops.placeholder(dtypes.int32) if context.executing_eagerly():
values = array_ops.placeholder(dtypes.float64) self.evaluate(
self._encode_module.encode_proto( self._encode_module.encode_proto(
sizes=sizes, sizes=sizes_value,
values=[values], values=[np.array([[0.0]])],
message_type='tensorflow.contrib.proto.TestValue', message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']).eval(feed_dict={ field_names=['double_value']))
sizes: [[[0, 0]]], else:
values: [[0.0]] with self.cached_session():
}) sizes = array_ops.placeholder(dtypes.int32)
values = array_ops.placeholder(dtypes.float64)
self._encode_module.encode_proto(
sizes=sizes,
values=[values],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']).eval(feed_dict={
sizes: sizes_value,
values: [[0.0]]
})
# Inconsistent shapes of values. # Inconsistent shapes of values.
with self.cached_session(): with self.assertRaisesOpError('Values must match up to the last dimension'):
with self.assertRaisesOpError( if context.executing_eagerly():
'Values must match up to the last dimension'): self.evaluate(
sizes = array_ops.placeholder(dtypes.int32) self._encode_module.encode_proto(
values1 = array_ops.placeholder(dtypes.float64) sizes=[[1, 1]],
values2 = array_ops.placeholder(dtypes.int32) values=[np.array([[0.0]]),
(self._encode_module.encode_proto( np.array([[0], [0]])],
sizes=[[1, 1]], message_type='tensorflow.contrib.proto.TestValue',
values=[values1, values2], field_names=['double_value', 'int32_value']))
message_type='tensorflow.contrib.proto.TestValue', else:
field_names=['double_value', 'int32_value']).eval(feed_dict={ with self.cached_session():
values1: [[0.0]], values1 = array_ops.placeholder(dtypes.float64)
values2: [[0], [0]] values2 = array_ops.placeholder(dtypes.int32)
})) (self._encode_module.encode_proto(
sizes=[[1, 1]],
values=[values1, values2],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value', 'int32_value']).eval(feed_dict={
values1: [[0.0]],
values2: [[0], [0]]
}))
def _testRoundtrip(self, in_bufs, message_type, fields): def _testRoundtrip(self, in_bufs, message_type, fields):

View File

@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel {
ctx, ctx,
proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()), proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
errors::InvalidArgument( errors::InvalidArgument(
"Incompatible type for field " + field_names_[i] + "Incompatible type for field ", field_names_[i],
". Saw dtype: ", ". Saw dtype: ", DataTypeString(v.dtype()),
DataTypeString(v.dtype()),
" but field type is: ", field_descs_[i]->type_name())); " but field type is: ", field_descs_[i]->type_name()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
errors::InvalidArgument("Invalid shape for field ", field_names_[i],
". Saw shape ", v.shape().DebugString(),
" but it should be at least a matrix."));
// All value tensors must have the same shape prefix (i.e. batch size). // All value tensors must have the same shape prefix (i.e. batch size).
TensorShape shape_prefix = v.shape(); TensorShape shape_prefix = v.shape();
shape_prefix.RemoveDim(shape_prefix.dims() - 1); shape_prefix.RemoveDim(shape_prefix.dims() - 1);