For ops with an input for the form "a: N * T", if N=0 and there's no other information about what type T should resolve to, then set T to its default value (if one is present).
PiperOrigin-RevId: 271996738
This commit is contained in:
parent
096554d255
commit
4d35b20572
tensorflow/python
@ -235,6 +235,8 @@ def make_tensor(v, arg_name):
|
||||
|
||||
def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
"""Convert sequence `l` to eager same-type Tensors."""
|
||||
if (not l) and (default_dtype is not None):
|
||||
return default_dtype, [] # List is empty; assume default dtype.
|
||||
EagerTensor = ops.EagerTensor # pylint: disable=invalid-name
|
||||
for x in l:
|
||||
if not isinstance(x, EagerTensor):
|
||||
|
@ -326,16 +326,18 @@ class OpsTest(test_util.TensorFlowTestCase):
|
||||
# Uses default
|
||||
ctx = context.context()
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
|
||||
self.assertEquals(t, dtypes.int32)
|
||||
self.assertEquals(r[0].dtype, dtypes.int32)
|
||||
self.assertEqual(t, dtypes.int32)
|
||||
self.assertEqual(r[0].dtype, dtypes.int32)
|
||||
t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
|
||||
self.assertEquals(t, dtypes.int64)
|
||||
self.assertEquals(r[0].dtype, dtypes.int64)
|
||||
self.assertEqual(t, dtypes.int64)
|
||||
self.assertEqual(r[0].dtype, dtypes.int64)
|
||||
t, r = execute.args_to_matching_eager([], ctx, dtypes.int64)
|
||||
self.assertEqual(t, dtypes.int64)
|
||||
# Doesn't use default
|
||||
t, r = execute.args_to_matching_eager(
|
||||
[['string', 'arg']], ctx, dtypes.int32)
|
||||
self.assertEquals(t, dtypes.string)
|
||||
self.assertEquals(r[0].dtype, dtypes.string)
|
||||
self.assertEqual(t, dtypes.string)
|
||||
self.assertEqual(r[0].dtype, dtypes.string)
|
||||
|
||||
def testFlattenLayer(self):
|
||||
flatten_layer = core.Flatten()
|
||||
|
@ -545,15 +545,19 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
# <number-attr> * <type-attr> case, where we are now setting
|
||||
# the <type-attr> based on this input
|
||||
if not base_types:
|
||||
raise TypeError(
|
||||
"Don't know how to infer type variable from empty input "
|
||||
"list passed to input '%s' of '%s' Op." %
|
||||
(input_name, op_type_name))
|
||||
attrs[input_arg.type_attr] = base_types[0]
|
||||
inferred_from[input_arg.type_attr] = input_name
|
||||
type_attr = _Attr(op_def, input_arg.type_attr)
|
||||
_SatisfiesTypeConstraint(base_types[0], type_attr,
|
||||
param_name=input_name)
|
||||
# If it's in default_type_attr_map, then wait to set it
|
||||
# (in "process remaining attrs", below).
|
||||
if input_arg.type_attr not in default_type_attr_map:
|
||||
raise TypeError(
|
||||
"Don't know how to infer type variable from empty input "
|
||||
"list passed to input '%s' of '%s' Op." %
|
||||
(input_name, op_type_name))
|
||||
else:
|
||||
attrs[input_arg.type_attr] = base_types[0]
|
||||
inferred_from[input_arg.type_attr] = input_name
|
||||
type_attr = _Attr(op_def, input_arg.type_attr)
|
||||
_SatisfiesTypeConstraint(base_types[0], type_attr,
|
||||
param_name=input_name)
|
||||
elif input_arg.type_attr:
|
||||
# <type-attr>
|
||||
attr_value = base_types[0]
|
||||
@ -620,6 +624,9 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
# Attrs whose names match Python keywords have an extra '_'
|
||||
# appended, so we must check for that as well.
|
||||
attrs[attr.name] = keywords.pop(attr.name + "_")
|
||||
elif attr.name in default_type_attr_map:
|
||||
attrs[attr.name] = default_type_attr_map[attr.name]
|
||||
inferred_from.setdefault(attr.name, "Default in OpDef")
|
||||
else:
|
||||
raise TypeError("No argument for attr " + attr.name)
|
||||
|
||||
|
@ -1048,11 +1048,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
|
||||
attr { key: 'M' value { i: 0 } }
|
||||
""", op.node_def)
|
||||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
op_def_library.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
|
||||
self.assertEqual(str(cm.exception),
|
||||
"Don't know how to infer type variable from empty input "
|
||||
"list passed to input 'a' of 'InPolymorphicTwice' Op.")
|
||||
op = op_def_library.apply_op(
|
||||
"InPolymorphicTwice", a=[], b=[3, 4], name="p")
|
||||
self.assertProtoEquals("""
|
||||
name: 'p' op: 'InPolymorphicTwice' input: 'p/b_0' input: 'p/b_1'
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
attr { key: 'N' value { i: 0 } }
|
||||
attr { key: 'M' value { i: 2 } }
|
||||
""", op.node_def)
|
||||
|
||||
op = op_def_library.apply_op(
|
||||
"InPolymorphicTwice", a=[], b=[3.0, 4.0], name="q")
|
||||
self.assertProtoEquals("""
|
||||
name: 'q' op: 'InPolymorphicTwice' input: 'q/b_0' input: 'q/b_1'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'N' value { i: 0 } }
|
||||
attr { key: 'M' value { i: 2 } }
|
||||
""", op.node_def)
|
||||
|
||||
# Empty input lists: assume defaut type for T.
|
||||
op = op_def_library.apply_op(
|
||||
"InPolymorphicTwice", a=[], b=[], name="r")
|
||||
self.assertProtoEquals("""
|
||||
name: 'r' op: 'InPolymorphicTwice'
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
attr { key: 'N' value { i: 0 } }
|
||||
attr { key: 'M' value { i: 0 } }
|
||||
""", op.node_def)
|
||||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
op_def_library.apply_op(
|
||||
|
@ -598,7 +598,7 @@ REGISTER_OP("NInTwoTypeVariables")
|
||||
REGISTER_OP("InPolymorphicTwice")
|
||||
.Input("a: N * T")
|
||||
.Input("b: M * T")
|
||||
.Attr("T: type")
|
||||
.Attr("T: type = DT_INT32")
|
||||
.Attr("N: int >= 0")
|
||||
.Attr("M: int >= 0")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
Loading…
Reference in New Issue
Block a user