diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 749e7ea792b..e6eb2d9c6d4 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -162,15 +162,15 @@ class SelectV2Op : public OpKernel { // 2-ary broadcast. // Combine `then` and `else`. - BCast elem_bcast(BCast::FromShape(then->shape()), - BCast::FromShape(else_->shape()), false); - OP_REQUIRES(ctx, elem_bcast.IsValid(), + BCast then_else_bcast(BCast::FromShape(then->shape()), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES(ctx, then_else_bcast.IsValid(), errors::InvalidArgument( "then ", then->shape().DebugString(), " and else ", else_->shape().DebugString(), " must be broadcastable")); // Combine `cond` with `then` and `else`. BCast bcast(BCast::FromShape(cond->shape()), - BCast::FromShape(BCast::ToShape(elem_bcast.output_shape())), + BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())), false); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( @@ -443,46 +443,6 @@ struct BatchSelectFunctor { d.parallelFor(batch, cost, work); } }; - -template -struct BCastSelectFunctor { - void operator()(const CPUDevice& d, - typename TTypes::Tensor output_tensor, - typename TTypes::ConstTensor cond_tensor, - typename TTypes::ConstTensor then_tensor, - typename TTypes::ConstTensor else_tensor, - typename Eigen::array cond_bcast, - typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } -}; - -#ifdef TENSORFLOW_USE_SYCL -template -struct BatchSelectFunctor - : BatchSelectFunctorBase {}; - -template -struct BCastSelectFunctor { - void operator()(const SYCLDevice& d, - typename TTypes::Tensor output_tensor, - typename TTypes::ConstTensor cond_tensor, - typename TTypes::ConstTensor then_tensor, - typename TTypes::ConstTensor else_tensor, - typename Eigen::array cond_bcast, - typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } -}; - -#endif // TENSORFLOW_USE_SYCL - } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index cc8eba69555..610ff07a4dc 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1217,7 +1217,11 @@ struct BCastSelectFunctor { typename TTypes::ConstTensor else_tensor, typename Eigen::array cond_bcast, typename Eigen::array then_bcast, - typename Eigen::array else_bcast); + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } }; } // end namespace functor