Make tf.contrib.proto.* TF2-friendly.

This included fixing a bug where shape inference caught an incorrect shape,
but since eager mode doesn't run shape inference the core code caused a
segfault.

PiperOrigin-RevId: 237316781
This commit is contained in:
Eugene Brevdo 2019-03-07 13:49:48 -08:00 committed by TensorFlower Gardener
parent 17baa62ff4
commit add7a1a911
3 changed files with 90 additions and 54 deletions

View File

@ -296,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names = ['sizes']
field_types = [dtypes.int32]
with self.cached_session() as sess:
ctensor, vtensor = self._decode_module.decode_proto(
batch,
message_type=msg_type,
field_names=field_names,
output_types=field_types,
sanitize=sanitize)
with self.assertRaisesRegexp(errors.DataLossError,
'Unable to parse binary protobuf'
'|Failed to consume entire buffer'):
_ = sess.run([ctensor] + vtensor)
with self.assertRaisesRegexp(
errors.DataLossError, 'Unable to parse binary protobuf'
'|Failed to consume entire buffer'):
self.evaluate(
self._decode_module.decode_proto(
batch,
message_type=msg_type,
field_names=field_names,
output_types=field_types,
sanitize=sanitize))

View File

@ -30,7 +30,9 @@ from google.protobuf import text_format
from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -50,56 +52,86 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
self._decode_module = decode_module
self._encode_module = encode_module
def testBadSizesShape(self):
if context.executing_eagerly():
expected_error = (errors.InvalidArgumentError,
r'Invalid shape for field double_value.')
else:
expected_error = (ValueError,
r'Shape must be at least rank 2 but is rank 0')
with self.assertRaisesRegexp(*expected_error):
self.evaluate(
self._encode_module.encode_proto(
sizes=1,
values=[np.double(1.0)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']))
def testBadInputs(self):
# Invalid field name
with self.cached_session():
with self.assertRaisesOpError('Unknown field: non_existent_field'):
self._encode_module.encode_proto(
sizes=[[1]],
values=[np.array([[0.0]], dtype=np.int32)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['non_existent_field']).eval()
with self.assertRaisesOpError('Unknown field: non_existent_field'):
self.evaluate(
self._encode_module.encode_proto(
sizes=[[1]],
values=[np.array([[0.0]], dtype=np.int32)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['non_existent_field']))
# Incorrect types.
with self.cached_session():
with self.assertRaisesOpError(
'Incompatible type for field double_value.'):
self._encode_module.encode_proto(
sizes=[[1]],
values=[np.array([[0.0]], dtype=np.int32)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']).eval()
with self.assertRaisesOpError('Incompatible type for field double_value.'):
self.evaluate(
self._encode_module.encode_proto(
sizes=[[1]],
values=[np.array([[0.0]], dtype=np.int32)],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']))
# Incorrect shapes of sizes.
with self.cached_session():
for sizes_value in 1, np.array([[[0, 0]]]):
with self.assertRaisesOpError(
r'sizes should be batch_size \+ \[len\(field_names\)\]'):
sizes = array_ops.placeholder(dtypes.int32)
values = array_ops.placeholder(dtypes.float64)
self._encode_module.encode_proto(
sizes=sizes,
values=[values],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']).eval(feed_dict={
sizes: [[[0, 0]]],
values: [[0.0]]
})
if context.executing_eagerly():
self.evaluate(
self._encode_module.encode_proto(
sizes=sizes_value,
values=[np.array([[0.0]])],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']))
else:
with self.cached_session():
sizes = array_ops.placeholder(dtypes.int32)
values = array_ops.placeholder(dtypes.float64)
self._encode_module.encode_proto(
sizes=sizes,
values=[values],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value']).eval(feed_dict={
sizes: sizes_value,
values: [[0.0]]
})
# Inconsistent shapes of values.
with self.cached_session():
with self.assertRaisesOpError(
'Values must match up to the last dimension'):
sizes = array_ops.placeholder(dtypes.int32)
values1 = array_ops.placeholder(dtypes.float64)
values2 = array_ops.placeholder(dtypes.int32)
(self._encode_module.encode_proto(
sizes=[[1, 1]],
values=[values1, values2],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value', 'int32_value']).eval(feed_dict={
values1: [[0.0]],
values2: [[0], [0]]
}))
with self.assertRaisesOpError('Values must match up to the last dimension'):
if context.executing_eagerly():
self.evaluate(
self._encode_module.encode_proto(
sizes=[[1, 1]],
values=[np.array([[0.0]]),
np.array([[0], [0]])],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value', 'int32_value']))
else:
with self.cached_session():
values1 = array_ops.placeholder(dtypes.float64)
values2 = array_ops.placeholder(dtypes.int32)
(self._encode_module.encode_proto(
sizes=[[1, 1]],
values=[values1, values2],
message_type='tensorflow.contrib.proto.TestValue',
field_names=['double_value', 'int32_value']).eval(feed_dict={
values1: [[0.0]],
values2: [[0], [0]]
}))
def _testRoundtrip(self, in_bufs, message_type, fields):

View File

@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel {
ctx,
proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
errors::InvalidArgument(
"Incompatible type for field " + field_names_[i] +
". Saw dtype: ",
DataTypeString(v.dtype()),
"Incompatible type for field ", field_names_[i],
". Saw dtype: ", DataTypeString(v.dtype()),
" but field type is: ", field_descs_[i]->type_name()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
errors::InvalidArgument("Invalid shape for field ", field_names_[i],
". Saw shape ", v.shape().DebugString(),
" but it should be at least a matrix."));
// All value tensors must have the same shape prefix (i.e. batch size).
TensorShape shape_prefix = v.shape();
shape_prefix.RemoveDim(shape_prefix.dims() - 1);