For ops with an input for the form "a: N * T", if N=0 and there's no other information about what type T should resolve to, then set T to its default value (if one is present).

PiperOrigin-RevId: 271996738
This commit is contained in:
Edward Loper 2019-09-30 09:15:54 -07:00 committed by TensorFlower Gardener
parent 096554d255
commit 4d35b20572
5 changed files with 54 additions and 21 deletions

View File

@ -235,6 +235,8 @@ def make_tensor(v, arg_name):
def args_to_matching_eager(l, ctx, default_dtype=None):
"""Convert sequence `l` to eager same-type Tensors."""
if (not l) and (default_dtype is not None):
return default_dtype, [] # List is empty; assume default dtype.
EagerTensor = ops.EagerTensor # pylint: disable=invalid-name
for x in l:
if not isinstance(x, EagerTensor):

View File

@ -326,16 +326,18 @@ class OpsTest(test_util.TensorFlowTestCase):
# Uses default
ctx = context.context()
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
self.assertEquals(t, dtypes.int32)
self.assertEquals(r[0].dtype, dtypes.int32)
self.assertEqual(t, dtypes.int32)
self.assertEqual(r[0].dtype, dtypes.int32)
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
self.assertEquals(t, dtypes.int64)
self.assertEquals(r[0].dtype, dtypes.int64)
self.assertEqual(t, dtypes.int64)
self.assertEqual(r[0].dtype, dtypes.int64)
t, r = execute.args_to_matching_eager([], ctx, dtypes.int64)
self.assertEqual(t, dtypes.int64)
# Doesn't use default
t, r = execute.args_to_matching_eager(
[['string', 'arg']], ctx, dtypes.int32)
self.assertEquals(t, dtypes.string)
self.assertEquals(r[0].dtype, dtypes.string)
self.assertEqual(t, dtypes.string)
self.assertEqual(r[0].dtype, dtypes.string)
def testFlattenLayer(self):
flatten_layer = core.Flatten()

View File

@ -545,15 +545,19 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
# <number-attr> * <type-attr> case, where we are now setting
# the <type-attr> based on this input
if not base_types:
raise TypeError(
"Don't know how to infer type variable from empty input "
"list passed to input '%s' of '%s' Op." %
(input_name, op_type_name))
attrs[input_arg.type_attr] = base_types[0]
inferred_from[input_arg.type_attr] = input_name
type_attr = _Attr(op_def, input_arg.type_attr)
_SatisfiesTypeConstraint(base_types[0], type_attr,
param_name=input_name)
# If it's in default_type_attr_map, then wait to set it
# (in "process remaining attrs", below).
if input_arg.type_attr not in default_type_attr_map:
raise TypeError(
"Don't know how to infer type variable from empty input "
"list passed to input '%s' of '%s' Op." %
(input_name, op_type_name))
else:
attrs[input_arg.type_attr] = base_types[0]
inferred_from[input_arg.type_attr] = input_name
type_attr = _Attr(op_def, input_arg.type_attr)
_SatisfiesTypeConstraint(base_types[0], type_attr,
param_name=input_name)
elif input_arg.type_attr:
# <type-attr>
attr_value = base_types[0]
@ -620,6 +624,9 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
# Attrs whose names match Python keywords have an extra '_'
# appended, so we must check for that as well.
attrs[attr.name] = keywords.pop(attr.name + "_")
elif attr.name in default_type_attr_map:
attrs[attr.name] = default_type_attr_map[attr.name]
inferred_from.setdefault(attr.name, "Default in OpDef")
else:
raise TypeError("No argument for attr " + attr.name)

View File

@ -1048,11 +1048,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
attr { key: 'M' value { i: 0 } }
""", op.node_def)
with self.assertRaises(TypeError) as cm:
op_def_library.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
self.assertEqual(str(cm.exception),
"Don't know how to infer type variable from empty input "
"list passed to input 'a' of 'InPolymorphicTwice' Op.")
op = op_def_library.apply_op(
"InPolymorphicTwice", a=[], b=[3, 4], name="p")
self.assertProtoEquals("""
name: 'p' op: 'InPolymorphicTwice' input: 'p/b_0' input: 'p/b_1'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 2 } }
""", op.node_def)
op = op_def_library.apply_op(
"InPolymorphicTwice", a=[], b=[3.0, 4.0], name="q")
self.assertProtoEquals("""
name: 'q' op: 'InPolymorphicTwice' input: 'q/b_0' input: 'q/b_1'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 2 } }
""", op.node_def)
# Empty input lists: assume defaut type for T.
op = op_def_library.apply_op(
"InPolymorphicTwice", a=[], b=[], name="r")
self.assertProtoEquals("""
name: 'r' op: 'InPolymorphicTwice'
attr { key: 'T' value { type: DT_INT32 } }
attr { key: 'N' value { i: 0 } }
attr { key: 'M' value { i: 0 } }
""", op.node_def)
with self.assertRaises(TypeError) as cm:
op_def_library.apply_op(

View File

@ -598,7 +598,7 @@ REGISTER_OP("NInTwoTypeVariables")
REGISTER_OP("InPolymorphicTwice")
.Input("a: N * T")
.Input("b: M * T")
.Attr("T: type")
.Attr("T: type = DT_INT32")
.Attr("N: int >= 0")
.Attr("M: int >= 0")
.SetShapeFn(shape_inference::UnknownShape);