diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 5d91e5f9c9c..feb10431d40 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -448,12 +448,8 @@ class ControlFlowTest(test.TestCase): values = constant_op.constant(10) indices = constant_op.constant(0) x = ops.IndexedSlices(values, indices) - v1_msg = "The two structures don't have the same nested structure" - v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same " - "number, type, and overall structure of return values.") with self.assertRaisesRegexp( - TypeError, - v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg): + TypeError, "Cannot reconcile tf.cond 0-th outputs"): control_flow_ops.cond( constant_op.constant(True), lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices), @@ -516,7 +512,6 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0]) self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0]) - @test_util.disable_control_flow_v2("b/113293074") @test_util.run_v1_only("b/120545219") def testCondIndexedSlicesDifferentTypes(self): with self.cached_session(): diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 38bcb84a2aa..5d661397b3d 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -185,6 +185,7 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs, A list of Tensors which are the outputs of the If op. Does not include added intermediate outputs. """ + _make_indexed_slices_indices_types_match(true_graph, false_graph) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that @@ -522,6 +523,63 @@ def _make_output_composite_tensors_match(true_graph, false_graph): false_graph.outputs = func_graph_module.flatten(false_outputs) +def _make_indexed_slices_indices_types_match(true_graph, false_graph): + """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.""" + indexed_slice_indices = [] + current_index = 0 + true_outputs_flat_with_composites = nest.flatten( + true_graph.structured_outputs, expand_composites=False) + false_outputs_flat_with_composites = nest.flatten( + false_graph.structured_outputs, expand_composites=False) + # Store indices of IndexedSlices.indices in `indexed_slice_indices`. + for idx, (true_out, false_out) in enumerate( + zip(true_outputs_flat_with_composites, + false_outputs_flat_with_composites)): + if isinstance(true_out, ops.IndexedSlices) != isinstance( + false_out, ops.IndexedSlices): + raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n" + " true_fn returned: %s\n" + " false_fn returned: %s" % (idx, true_out, false_out)) + if isinstance(true_out, ops.IndexedSlices): + # indices is the second component of the composite tensor. + indexed_slice_indices.append(current_index + 1) + if nest.is_sequence_or_composite(true_out): + current_index += len(nest.flatten(true_out, expand_composites=True)) + else: + current_index += 1 + + if not indexed_slice_indices: + return + + if current_index != len(true_graph.outputs): + raise ValueError("Insufficient elements in true_graph.outputs.\n" + "Expected: %i\n" + "Actual: %i" % (current_index, len(true_graph.outputs))) + + # Cast indices with mismatching types to int64. + for index in indexed_slice_indices: + if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64): + raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " + "Found: %s" % str(true_graph.outputs[index].dtype)) + if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64): + raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " + "Found: %s" % str(false_graph.outputs[index].dtype)) + if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype: + if false_graph.outputs[index].dtype == dtypes.int32: + with false_graph.as_default(): + false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index], + dtypes.int64) + else: + with true_graph.as_default(): + true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index], + dtypes.int64) + + true_graph.structured_outputs = func_graph_module.pack_sequence_as( + true_graph.structured_outputs, true_graph.outputs) + false_graph.structured_outputs = func_graph_module.pack_sequence_as( + false_graph.structured_outputs, false_graph.outputs) + + def _wrap_intermediates(func_graph, intermediates): with func_graph.as_default(): return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]