From ec14e1b42947fc29abd8f677f8bdfde15c754ac2 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Fri, 18 Sep 2020 21:16:05 -0700 Subject: [PATCH] Fix undefined behavior in `tf.raw_ops.Switch` in eager mode. PiperOrigin-RevId: 332578058 Change-Id: I9727571d2f21476b10d8aa27c1b7176564b76ac9 --- tensorflow/core/common_runtime/eager/kernel_and_device.cc | 7 ++++++- .../python/kernel_tests/control_flow_ops_py_test.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index d3f6cb2a080..84c52bf3d54 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -336,7 +336,12 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container, if (outputs != nullptr) { outputs->clear(); for (int i = 0; i < context.num_outputs(); ++i) { - outputs->push_back(Tensor(*context.mutable_output(i))); + const auto* output_tensor = context.mutable_output(i); + if (output_tensor != nullptr) { + outputs->push_back(Tensor(*output_tensor)); + } else { + outputs->push_back(Tensor()); + } } } if (stats != nullptr) { diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 9acaec4f039..ba92cac3c77 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -4505,6 +4505,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): result = control_flow_ops.merge([v_f, v_t]) self.evaluate(result) + def testSwitchEagerMode(self): + if not context.executing_eagerly(): + return + input_data = [1, 2, 3, 4] + vf, vt = control_flow_ops.switch(input_data, False) + self.assertAllEqual(vf, input_data) + self.assertAllEqual(vt, []) + @test_util.run_deprecated_v1 def testQIntArgAndRet(self):