Fix undefined behavior in tf.raw_ops.Switch in eager mode.

PiperOrigin-RevId: 332578058
Change-Id: I9727571d2f21476b10d8aa27c1b7176564b76ac9
This commit is contained in:
Mihai Maruseac 2020-09-18 21:16:05 -07:00
parent 0615b26093
commit ec14e1b429
2 changed files with 14 additions and 1 deletions

View File

@ -336,7 +336,12 @@ Status KernelAndDeviceOp::Run(ScopedStepContainer* step_container,
if (outputs != nullptr) { if (outputs != nullptr) {
outputs->clear(); outputs->clear();
for (int i = 0; i < context.num_outputs(); ++i) { for (int i = 0; i < context.num_outputs(); ++i) {
outputs->push_back(Tensor(*context.mutable_output(i))); const auto* output_tensor = context.mutable_output(i);
if (output_tensor != nullptr) {
outputs->push_back(Tensor(*output_tensor));
} else {
outputs->push_back(Tensor());
}
} }
} }
if (stats != nullptr) { if (stats != nullptr) {

View File

@ -4505,6 +4505,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
result = control_flow_ops.merge([v_f, v_t]) result = control_flow_ops.merge([v_f, v_t])
self.evaluate(result) self.evaluate(result)
def testSwitchEagerMode(self):
if not context.executing_eagerly():
return
input_data = [1, 2, 3, 4]
vf, vt = control_flow_ops.switch(input_data, False)
self.assertAllEqual(vf, input_data)
self.assertAllEqual(vt, [])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testQIntArgAndRet(self): def testQIntArgAndRet(self):