diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc index 318894bfce4..598cb526d77 100644 --- a/tensorflow/core/kernels/where_op.cc +++ b/tensorflow/core/kernels/where_op.cc @@ -75,7 +75,7 @@ template struct NumTrue { static Status Compute(OpKernelContext* ctx, const CPUDevice& d, typename TTypes::ConstFlat input, - TTypes::Scalar num_true) { + TTypes::UnalignedScalar num_true) { num_true() = CountAccumulator(input.data(), input.data() + input.size()); return Status::OK(); } @@ -140,18 +140,14 @@ class WhereCPUOp : public OpKernel { const int input_dims = input.dims(); - Tensor num_true; - AllocatorAttributes attr; - attr.set_on_host(true); - OP_REQUIRES_OK(context, context->allocate_temp(DT_INT64, TensorShape({}), - &num_true, attr)); - auto num_true_t = num_true.scalar(); + int64 num_true; + TTypes::UnalignedScalar num_true_t(&num_true); Status s = functor::NumTrue::Compute( context, context->eigen_device(), input.flat(), num_true_t); OP_REQUIRES_OK(context, s); - TensorShape output_shape({num_true_t(), input_dims}); + TensorShape output_shape({num_true, input_dims}); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); @@ -216,7 +212,7 @@ namespace functor { template <> \ Status NumTrue::Compute( \ OpKernelContext* ctx, const GPUDevice& d, TTypes::ConstFlat input, \ - TTypes::Scalar num_true); \ + TTypes::UnalignedScalar num_true); \ extern template struct NumTrue #define DECLARE_GPU_NUMTRUE_TYPE(T) \ @@ -287,8 +283,8 @@ class WhereGPUOp : public AsyncOpKernel { context->allocate_temp(DataTypeToEnum::v(), TensorShape({}), &num_true), done); - - auto num_true_t = num_true.scalar(); + typename TTypes::UnalignedScalar num_true_t( + num_true.scalar().data()); se::DeviceMemoryBase num_true_ptr(static_cast(num_true_t.data())); // Push kernel to stream to get number of true elements. diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h index 7297d37ffb8..58d38139f3a 100644 --- a/tensorflow/core/kernels/where_op.h +++ b/tensorflow/core/kernels/where_op.h @@ -41,7 +41,7 @@ struct NumTrue { EIGEN_ALWAYS_INLINE static Status Compute( OpKernelContext* ctx, const Device& d, typename TTypes::ConstFlat input, - typename TTypes::Scalar num_true); + typename TTypes::UnalignedScalar num_true); }; template diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index 3795733f959..f13f504c1d7 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -149,7 +149,7 @@ struct NumTrue { EIGEN_ALWAYS_INLINE static Status Compute( OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstFlat input, - typename TTypes::Scalar num_true) { + typename TTypes::UnalignedScalar num_true) { const auto& cu_stream = GetGpuStream(ctx); std::size_t temp_storage_bytes = 0;