Avoid copying Tensor when Select op takes scalar cond input (CPU).

PiperOrigin-RevId: 208868429
This commit is contained in:
Sung Jin Hwang 2018-08-15 13:04:08 -07:00 committed by TensorFlower Gardener
parent ec5f4771e4
commit bc646fd576

View File

@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
} // namespace functor
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
then->shape().DebugString(), " vs. ",
else_->shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
{"t", "e"}, "output", then->shape(), &output));
if (output->NumElements() > 0) {
functor::SelectScalarFunctor<Device, T> func;
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
then->flat<T>(), else_->flat<T>());
}
functor::SelectScalarHandler<Device, T> handler;
handler(ctx, cond, then, else_);
}
private:
@ -208,6 +205,40 @@ template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarHandler {
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
const Tensor* else_) {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
{"t", "e"}, "output", then->shape(), &output));
if (output->NumElements() > 0) {
functor::SelectScalarFunctor<Device, T> func;
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
then->flat<T>(), else_->flat<T>());
}
}
};
// Specilization for CPU device. Forward input to output depending on the `cond`
// value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
const Tensor* else_) {
if (cond->scalar<bool>()()) {
OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
} else {
OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
}
}
};
#ifdef TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarFunctorBase {
void operator()(const Device& d, typename TTypes<T>::Flat out,
@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
}
};
// CPU Specializations of Select functors with scalar
template <typename T>
struct SelectScalarFunctor<CPUDevice, T>
: SelectScalarFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
: SelectScalarFunctorBase<SYCLDevice, T> {};