diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 88a942648f0..a27f8893925 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -69,11 +69,12 @@ class FunctionalizeCondTest : public ::testing::Test { namespace { -// TODO(jpienaar): Re-enable. Disabling for ASAN failure. -TEST_F(FunctionalizeCondTest, DISABLED_ScopeIn) { +TEST_F(FunctionalizeCondTest, ScopeIn) { Tensor pred_tensor(DT_BOOL, TensorShape()); + pred_tensor.flat().setZero(); Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); Node* s = test::graph::Switch(graph_.get(), val, pred); @@ -112,11 +113,12 @@ TEST_F(FunctionalizeCondTest, DISABLED_ScopeIn) { } } -// TODO(jpienaar): Re-enable. Disabling for ASAN failure. -TEST_F(FunctionalizeCondTest, DISABLED_JoinCondStates) { +TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor pred_tensor(DT_BOOL, TensorShape()); + pred_tensor.flat().setZero(); Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); Node* s = test::graph::Switch(graph_.get(), val, pred);