Avoid copying Tensor when Select op takes scalar cond input (CPU).
PiperOrigin-RevId: 208868429
This commit is contained in:
parent
ec5f4771e4
commit
bc646fd576
@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||||||
typedef Eigen::SyclDevice SYCLDevice;
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct SelectScalarHandler;
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class SelectOp : public OpKernel {
|
class SelectOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
|
|||||||
then->shape().DebugString(), " vs. ",
|
then->shape().DebugString(), " vs. ",
|
||||||
else_->shape().DebugString()));
|
else_->shape().DebugString()));
|
||||||
|
|
||||||
Tensor* output = nullptr;
|
functor::SelectScalarHandler<Device, T> handler;
|
||||||
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
handler(ctx, cond, then, else_);
|
||||||
{"t", "e"}, "output", then->shape(), &output));
|
|
||||||
|
|
||||||
if (output->NumElements() > 0) {
|
|
||||||
functor::SelectScalarFunctor<Device, T> func;
|
|
||||||
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
|
|
||||||
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
|
|
||||||
then->flat<T>(), else_->flat<T>());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -208,6 +205,40 @@ template <typename T>
|
|||||||
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
|
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct SelectScalarHandler {
|
||||||
|
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
|
||||||
|
const Tensor* else_) {
|
||||||
|
Tensor* output = nullptr;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
||||||
|
{"t", "e"}, "output", then->shape(), &output));
|
||||||
|
|
||||||
|
if (output->NumElements() > 0) {
|
||||||
|
functor::SelectScalarFunctor<Device, T> func;
|
||||||
|
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
|
||||||
|
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
|
||||||
|
then->flat<T>(), else_->flat<T>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specialization for CPU device. Forward input to output depending on the `cond`
|
||||||
|
// value.
|
||||||
|
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
|
||||||
|
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
|
||||||
|
template <typename T>
|
||||||
|
struct SelectScalarHandler<CPUDevice, T> {
|
||||||
|
void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
|
||||||
|
const Tensor* else_) {
|
||||||
|
if (cond->scalar<bool>()()) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct SelectScalarFunctorBase {
|
struct SelectScalarFunctorBase {
|
||||||
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
void operator()(const Device& d, typename TTypes<T>::Flat out,
|
||||||
@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// CPU Specializations of Select functors with scalar
|
|
||||||
template <typename T>
|
|
||||||
struct SelectScalarFunctor<CPUDevice, T>
|
|
||||||
: SelectScalarFunctorBase<CPUDevice, T> {};
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct SelectScalarFunctor<SYCLDevice, T>
|
struct SelectScalarFunctor<SYCLDevice, T>
|
||||||
: SelectScalarFunctorBase<SYCLDevice, T> {};
|
: SelectScalarFunctorBase<SYCLDevice, T> {};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user