From 33cd7b88acefb737e45634b38c26c55783deafdf Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 2 May 2019 23:39:59 +0000 Subject: [PATCH] Fix GPU build failure Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_select.cc | 32 ++++++++++++++++++++++ tensorflow/core/kernels/cwise_ops.h | 6 +--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index e6eb2d9c6d4..402d24d5b5b 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -443,6 +443,38 @@ struct BatchSelectFunctor { d.parallelFor(batch, cost, work); } }; + +template +struct BCastSelectFunctorBase { + void operator()(const Device& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct BatchSelectFunctor + : BatchSelectFunctorBase {}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 610ff07a4dc..cc8eba69555 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1217,11 +1217,7 @@ struct BCastSelectFunctor { typename TTypes::ConstTensor else_tensor, typename Eigen::array cond_bcast, typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } + typename Eigen::array else_bcast); }; } // end namespace functor