Avoid copying Tensor when Select op takes scalar cond input (CPU).
PiperOrigin-RevId: 208868429
This commit is contained in:
parent
ec5f4771e4
commit
bc646fd576
@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
namespace functor {
|
||||
template <typename Device, typename T>
|
||||
struct SelectScalarHandler;
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
class SelectOp : public OpKernel {
|
||||
public:
|
||||
@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
|
||||
then->shape().DebugString(), " vs. ",
|
||||
else_->shape().DebugString()));
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
||||
{"t", "e"}, "output", then->shape(), &output));
|
||||
|
||||
if (output->NumElements() > 0) {
|
||||
functor::SelectScalarFunctor<Device, T> func;
|
||||
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
|
||||
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
|
||||
then->flat<T>(), else_->flat<T>());
|
||||
}
|
||||
functor::SelectScalarHandler<Device, T> handler;
|
||||
handler(ctx, cond, then, else_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -208,6 +205,40 @@ template <typename T>
|
||||
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct SelectScalarHandler {
|
||||
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
|
||||
const Tensor* else_) {
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
||||
{"t", "e"}, "output", then->shape(), &output));
|
||||
|
||||
if (output->NumElements() > 0) {
|
||||
functor::SelectScalarFunctor<Device, T> func;
|
||||
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
|
||||
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
|
||||
then->flat<T>(), else_->flat<T>());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Specilization for CPU device. Forward input to output depending on the `cond`
|
||||
// value.
|
||||
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
|
||||
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
|
||||
template <typename T>
|
||||
struct SelectScalarHandler<CPUDevice, T> {
|
||||
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
|
||||
const Tensor* else_) {
|
||||
if (cond->scalar<bool>()()) {
|
||||
OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename Device, typename T>
|
||||
struct SelectScalarFunctorBase {
|
||||
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
||||
@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
|
||||
}
|
||||
};
|
||||
|
||||
// CPU Specializations of Select functors with scalar
|
||||
template <typename T>
|
||||
struct SelectScalarFunctor<CPUDevice, T>
|
||||
: SelectScalarFunctorBase<CPUDevice, T> {};
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename T>
|
||||
struct SelectScalarFunctor<SYCLDevice, T>
|
||||
: SelectScalarFunctorBase<SYCLDevice, T> {};
|
||||
|
Loading…
x
Reference in New Issue
Block a user