diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 30c9ca6217d..92e22beefae 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -235,6 +235,8 @@ def make_tensor(v, arg_name): def args_to_matching_eager(l, ctx, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" + if (not l) and (default_dtype is not None): + return default_dtype, [] # List is empty; assume default dtype. EagerTensor = ops.EagerTensor # pylint: disable=invalid-name for x in l: if not isinstance(x, EagerTensor): diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index cebf786d6a8..c7748fd12e1 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -326,16 +326,18 @@ class OpsTest(test_util.TensorFlowTestCase): # Uses default ctx = context.context() t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32) - self.assertEquals(t, dtypes.int32) - self.assertEquals(r[0].dtype, dtypes.int32) + self.assertEqual(t, dtypes.int32) + self.assertEqual(r[0].dtype, dtypes.int32) t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64) - self.assertEquals(t, dtypes.int64) - self.assertEquals(r[0].dtype, dtypes.int64) + self.assertEqual(t, dtypes.int64) + self.assertEqual(r[0].dtype, dtypes.int64) + t, r = execute.args_to_matching_eager([], ctx, dtypes.int64) + self.assertEqual(t, dtypes.int64) # Doesn't use default t, r = execute.args_to_matching_eager( [['string', 'arg']], ctx, dtypes.int32) - self.assertEquals(t, dtypes.string) - self.assertEquals(r[0].dtype, dtypes.string) + self.assertEqual(t, dtypes.string) + self.assertEqual(r[0].dtype, dtypes.string) def testFlattenLayer(self): flatten_layer = core.Flatten() diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index ec780f26d6b..e8479b10bed 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -545,15 +545,19 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in # * case, where we are now setting # the based on this input if not base_types: - raise TypeError( - "Don't know how to infer type variable from empty input " - "list passed to input '%s' of '%s' Op." % - (input_name, op_type_name)) - attrs[input_arg.type_attr] = base_types[0] - inferred_from[input_arg.type_attr] = input_name - type_attr = _Attr(op_def, input_arg.type_attr) - _SatisfiesTypeConstraint(base_types[0], type_attr, - param_name=input_name) + # If it's in default_type_attr_map, then wait to set it + # (in "process remaining attrs", below). + if input_arg.type_attr not in default_type_attr_map: + raise TypeError( + "Don't know how to infer type variable from empty input " + "list passed to input '%s' of '%s' Op." % + (input_name, op_type_name)) + else: + attrs[input_arg.type_attr] = base_types[0] + inferred_from[input_arg.type_attr] = input_name + type_attr = _Attr(op_def, input_arg.type_attr) + _SatisfiesTypeConstraint(base_types[0], type_attr, + param_name=input_name) elif input_arg.type_attr: # attr_value = base_types[0] @@ -620,6 +624,9 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in # Attrs whose names match Python keywords have an extra '_' # appended, so we must check for that as well. attrs[attr.name] = keywords.pop(attr.name + "_") + elif attr.name in default_type_attr_map: + attrs[attr.name] = default_type_attr_map[attr.name] + inferred_from.setdefault(attr.name, "Default in OpDef") else: raise TypeError("No argument for attr " + attr.name) diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 83990d7648b..dda42f246e0 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -1048,11 +1048,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'M' value { i: 0 } } """, op.node_def) - with self.assertRaises(TypeError) as cm: - op_def_library.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) - self.assertEqual(str(cm.exception), - "Don't know how to infer type variable from empty input " - "list passed to input 'a' of 'InPolymorphicTwice' Op.") + op = op_def_library.apply_op( + "InPolymorphicTwice", a=[], b=[3, 4], name="p") + self.assertProtoEquals(""" + name: 'p' op: 'InPolymorphicTwice' input: 'p/b_0' input: 'p/b_1' + attr { key: 'T' value { type: DT_INT32 } } + attr { key: 'N' value { i: 0 } } + attr { key: 'M' value { i: 2 } } + """, op.node_def) + + op = op_def_library.apply_op( + "InPolymorphicTwice", a=[], b=[3.0, 4.0], name="q") + self.assertProtoEquals(""" + name: 'q' op: 'InPolymorphicTwice' input: 'q/b_0' input: 'q/b_1' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'N' value { i: 0 } } + attr { key: 'M' value { i: 2 } } + """, op.node_def) + + # Empty input lists: assume defaut type for T. + op = op_def_library.apply_op( + "InPolymorphicTwice", a=[], b=[], name="r") + self.assertProtoEquals(""" + name: 'r' op: 'InPolymorphicTwice' + attr { key: 'T' value { type: DT_INT32 } } + attr { key: 'N' value { i: 0 } } + attr { key: 'M' value { i: 0 } } + """, op.node_def) with self.assertRaises(TypeError) as cm: op_def_library.apply_op( diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index 623081c2dc5..fc864692b7b 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -598,7 +598,7 @@ REGISTER_OP("NInTwoTypeVariables") REGISTER_OP("InPolymorphicTwice") .Input("a: N * T") .Input("b: M * T") - .Attr("T: type") + .Attr("T: type = DT_INT32") .Attr("N: int >= 0") .Attr("M: int >= 0") .SetShapeFn(shape_inference::UnknownShape);