Automatically cast indices of IndexedSlices returned by cond_v2 from int32 to
int64 to match the IndexedSlices returned from the other branch. PiperOrigin-RevId: 247433509
This commit is contained in:
parent
078f2ab6da
commit
e00241511d
@ -448,12 +448,8 @@ class ControlFlowTest(test.TestCase):
|
|||||||
values = constant_op.constant(10)
|
values = constant_op.constant(10)
|
||||||
indices = constant_op.constant(0)
|
indices = constant_op.constant(0)
|
||||||
x = ops.IndexedSlices(values, indices)
|
x = ops.IndexedSlices(values, indices)
|
||||||
v1_msg = "The two structures don't have the same nested structure"
|
|
||||||
v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
|
|
||||||
"number, type, and overall structure of return values.")
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError,
|
TypeError, "Cannot reconcile tf.cond 0-th outputs"):
|
||||||
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
|
||||||
control_flow_ops.cond(
|
control_flow_ops.cond(
|
||||||
constant_op.constant(True),
|
constant_op.constant(True),
|
||||||
lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
|
lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
|
||||||
@ -516,7 +512,6 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
|
self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
|
||||||
self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
|
self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/113293074")
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testCondIndexedSlicesDifferentTypes(self):
|
def testCondIndexedSlicesDifferentTypes(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
@ -185,6 +185,7 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
|
|||||||
A list of Tensors which are the outputs of the If op. Does not include added
|
A list of Tensors which are the outputs of the If op. Does not include added
|
||||||
intermediate outputs.
|
intermediate outputs.
|
||||||
"""
|
"""
|
||||||
|
_make_indexed_slices_indices_types_match(true_graph, false_graph)
|
||||||
_check_same_outputs(true_graph, false_graph)
|
_check_same_outputs(true_graph, false_graph)
|
||||||
|
|
||||||
# Add inputs to true_graph and false_graph to make them match. Note that
|
# Add inputs to true_graph and false_graph to make them match. Note that
|
||||||
@ -522,6 +523,63 @@ def _make_output_composite_tensors_match(true_graph, false_graph):
|
|||||||
false_graph.outputs = func_graph_module.flatten(false_outputs)
|
false_graph.outputs = func_graph_module.flatten(false_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
|
||||||
|
"""Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
|
||||||
|
indexed_slice_indices = []
|
||||||
|
current_index = 0
|
||||||
|
true_outputs_flat_with_composites = nest.flatten(
|
||||||
|
true_graph.structured_outputs, expand_composites=False)
|
||||||
|
false_outputs_flat_with_composites = nest.flatten(
|
||||||
|
false_graph.structured_outputs, expand_composites=False)
|
||||||
|
# Store indices of IndexedSlices.indices in `indexed_slice_indices`.
|
||||||
|
for idx, (true_out, false_out) in enumerate(
|
||||||
|
zip(true_outputs_flat_with_composites,
|
||||||
|
false_outputs_flat_with_composites)):
|
||||||
|
if isinstance(true_out, ops.IndexedSlices) != isinstance(
|
||||||
|
false_out, ops.IndexedSlices):
|
||||||
|
raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
|
||||||
|
" true_fn returned: %s\n"
|
||||||
|
" false_fn returned: %s" % (idx, true_out, false_out))
|
||||||
|
if isinstance(true_out, ops.IndexedSlices):
|
||||||
|
# indices is the second component of the composite tensor.
|
||||||
|
indexed_slice_indices.append(current_index + 1)
|
||||||
|
if nest.is_sequence_or_composite(true_out):
|
||||||
|
current_index += len(nest.flatten(true_out, expand_composites=True))
|
||||||
|
else:
|
||||||
|
current_index += 1
|
||||||
|
|
||||||
|
if not indexed_slice_indices:
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_index != len(true_graph.outputs):
|
||||||
|
raise ValueError("Insufficient elements in true_graph.outputs.\n"
|
||||||
|
"Expected: %i\n"
|
||||||
|
"Actual: %i" % (current_index, len(true_graph.outputs)))
|
||||||
|
|
||||||
|
# Cast indices with mismatching types to int64.
|
||||||
|
for index in indexed_slice_indices:
|
||||||
|
if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
|
||||||
|
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
|
||||||
|
"Found: %s" % str(true_graph.outputs[index].dtype))
|
||||||
|
if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
|
||||||
|
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
|
||||||
|
"Found: %s" % str(false_graph.outputs[index].dtype))
|
||||||
|
if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
|
||||||
|
if false_graph.outputs[index].dtype == dtypes.int32:
|
||||||
|
with false_graph.as_default():
|
||||||
|
false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
|
||||||
|
dtypes.int64)
|
||||||
|
else:
|
||||||
|
with true_graph.as_default():
|
||||||
|
true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
|
||||||
|
dtypes.int64)
|
||||||
|
|
||||||
|
true_graph.structured_outputs = func_graph_module.pack_sequence_as(
|
||||||
|
true_graph.structured_outputs, true_graph.outputs)
|
||||||
|
false_graph.structured_outputs = func_graph_module.pack_sequence_as(
|
||||||
|
false_graph.structured_outputs, false_graph.outputs)
|
||||||
|
|
||||||
|
|
||||||
def _wrap_intermediates(func_graph, intermediates):
|
def _wrap_intermediates(func_graph, intermediates):
|
||||||
with func_graph.as_default():
|
with func_graph.as_default():
|
||||||
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
|
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
|
||||||
|
Loading…
Reference in New Issue
Block a user