diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index e6eb2d9c6d4..402d24d5b5b 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -443,6 +443,38 @@ struct BatchSelectFunctor { d.parallelFor(batch, cost, work); } }; + +template +struct BCastSelectFunctorBase { + void operator()(const Device& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct BatchSelectFunctor + : BatchSelectFunctorBase {}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 610ff07a4dc..cc8eba69555 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1217,11 +1217,7 @@ struct BCastSelectFunctor { typename TTypes::ConstTensor else_tensor, typename Eigen::array cond_bcast, typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } + typename Eigen::array else_bcast); }; } // end namespace functor