Automatically cast indices of IndexedSlices returned by cond_v2 from int32 to
int64 to match the IndexedSlices returned from the other branch. PiperOrigin-RevId: 247433509
This commit is contained in:
parent
078f2ab6da
commit
e00241511d
@ -448,12 +448,8 @@ class ControlFlowTest(test.TestCase):
|
||||
values = constant_op.constant(10)
|
||||
indices = constant_op.constant(0)
|
||||
x = ops.IndexedSlices(values, indices)
|
||||
v1_msg = "The two structures don't have the same nested structure"
|
||||
v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
|
||||
"number, type, and overall structure of return values.")
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
||||
TypeError, "Cannot reconcile tf.cond 0-th outputs"):
|
||||
control_flow_ops.cond(
|
||||
constant_op.constant(True),
|
||||
lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
|
||||
@ -516,7 +512,6 @@ class ControlFlowTest(test.TestCase):
|
||||
self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
|
||||
self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
|
||||
|
||||
@test_util.disable_control_flow_v2("b/113293074")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testCondIndexedSlicesDifferentTypes(self):
|
||||
with self.cached_session():
|
||||
|
@ -185,6 +185,7 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
|
||||
A list of Tensors which are the outputs of the If op. Does not include added
|
||||
intermediate outputs.
|
||||
"""
|
||||
_make_indexed_slices_indices_types_match(true_graph, false_graph)
|
||||
_check_same_outputs(true_graph, false_graph)
|
||||
|
||||
# Add inputs to true_graph and false_graph to make them match. Note that
|
||||
@ -522,6 +523,63 @@ def _make_output_composite_tensors_match(true_graph, false_graph):
|
||||
false_graph.outputs = func_graph_module.flatten(false_outputs)
|
||||
|
||||
|
||||
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
|
||||
"""Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
|
||||
indexed_slice_indices = []
|
||||
current_index = 0
|
||||
true_outputs_flat_with_composites = nest.flatten(
|
||||
true_graph.structured_outputs, expand_composites=False)
|
||||
false_outputs_flat_with_composites = nest.flatten(
|
||||
false_graph.structured_outputs, expand_composites=False)
|
||||
# Store indices of IndexedSlices.indices in `indexed_slice_indices`.
|
||||
for idx, (true_out, false_out) in enumerate(
|
||||
zip(true_outputs_flat_with_composites,
|
||||
false_outputs_flat_with_composites)):
|
||||
if isinstance(true_out, ops.IndexedSlices) != isinstance(
|
||||
false_out, ops.IndexedSlices):
|
||||
raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
|
||||
" true_fn returned: %s\n"
|
||||
" false_fn returned: %s" % (idx, true_out, false_out))
|
||||
if isinstance(true_out, ops.IndexedSlices):
|
||||
# indices is the second component of the composite tensor.
|
||||
indexed_slice_indices.append(current_index + 1)
|
||||
if nest.is_sequence_or_composite(true_out):
|
||||
current_index += len(nest.flatten(true_out, expand_composites=True))
|
||||
else:
|
||||
current_index += 1
|
||||
|
||||
if not indexed_slice_indices:
|
||||
return
|
||||
|
||||
if current_index != len(true_graph.outputs):
|
||||
raise ValueError("Insufficient elements in true_graph.outputs.\n"
|
||||
"Expected: %i\n"
|
||||
"Actual: %i" % (current_index, len(true_graph.outputs)))
|
||||
|
||||
# Cast indices with mismatching types to int64.
|
||||
for index in indexed_slice_indices:
|
||||
if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
|
||||
"Found: %s" % str(true_graph.outputs[index].dtype))
|
||||
if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
|
||||
"Found: %s" % str(false_graph.outputs[index].dtype))
|
||||
if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
|
||||
if false_graph.outputs[index].dtype == dtypes.int32:
|
||||
with false_graph.as_default():
|
||||
false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
|
||||
dtypes.int64)
|
||||
else:
|
||||
with true_graph.as_default():
|
||||
true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
|
||||
dtypes.int64)
|
||||
|
||||
true_graph.structured_outputs = func_graph_module.pack_sequence_as(
|
||||
true_graph.structured_outputs, true_graph.outputs)
|
||||
false_graph.structured_outputs = func_graph_module.pack_sequence_as(
|
||||
false_graph.structured_outputs, false_graph.outputs)
|
||||
|
||||
|
||||
def _wrap_intermediates(func_graph, intermediates):
|
||||
with func_graph.as_default():
|
||||
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
|
||||
|
Loading…
Reference in New Issue
Block a user