Remove OpKernel* templating from AddNOp.
PiperOrigin-RevId: 347733503 Change-Id: I12882f5e7ce32b5d7c553459a5ce3bd7d4967eea
parent c3900bfd70, commit 029b72e374
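In short, the kernel drops the extra OpKernel* template parameters and returns to the standard two-parameter form. The before/after signatures, as they appear in the diff below:

// Before: AddNOp was parameterized on the framework types themselves.
template <typename Device, typename T, class OpKernelT,
          class OpKernelConstructionT, class OpKernelContextT>
class AddNOp : public OpKernelT { /* ... */ };

// After: the usual device/dtype-only kernel template.
template <typename Device, typename T>
class AddNOp : public OpKernel { /* ... */ };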
tensorflow/core/kernels/BUILD
@@ -3230,6 +3230,7 @@ MATH_DEPS = [
    "//tensorflow/core:lib_internal",
    "//tensorflow/core:math_grad",
    "//tensorflow/core/framework:bounds_check",
+   "//tensorflow/core/framework:op_requires",
    "//third_party/eigen3",
]
@@ -3281,23 +3282,6 @@ tf_kernel_library(
    deps = MATH_DEPS,
)

-cc_library(
-    name = "aggregate_ops_headers",
-    hdrs = [
-        "aggregate_ops.h",
-        "aggregate_ops_cpu.h",
-    ],
-    deps = select({
-        "//tensorflow:android": [
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            "//third_party/eigen3",
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
tf_kernel_library(
    name = "argmax_op",
    prefix = "argmax_op",
tensorflow/core/kernels/aggregate_ops.cc
@@ -19,21 +19,238 @@ limitations under the License.

#include "tensorflow/core/kernels/aggregate_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    // Try to forward and accumulate the result in one of the input buffers.
    int reused_input = -1;
    gtl::InlinedVector<int, 8> input_indices(num);
    std::iota(input_indices.begin(), input_indices.end(), 0);
    Tensor* output = nullptr;
    for (int input_idx = 0; input_idx < num; ++input_idx) {
      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
                                                  &output)) {
        reused_input = input_idx;
        break;
      }
    }
    if (reused_input == -1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
    } else if (reused_input > 0) {
      // Move the forwarded buffer to the front so we don't double count
      // anything if there are more than 8 inputs.
      input_indices[0] = reused_input;
      input_indices[reused_input] = 0;
    }
    auto To = output->flat<T>();

#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()

#if defined(__ANDROID_TYPES_SLIM__)
    // On Android by default, we only support additions of two arguments,
    // so we can reduce the number of template instantiations.
    OP_REQUIRES(ctx, num == 2,
                errors::InvalidArgument("Only additions of two arguments "
                                        "supported. Num inputs: ",
                                        num));
    functor::Add2Functor<Device, T> functor2;
    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
#else
    static const int kWidth = 8;
    int r = num % kWidth;

    switch (r) {
      case 2: {
        functor::Add2Functor<Device, T> functor2;
        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
        break;
      }
      case 3: {
        functor::Add3Functor<Device, T> functor3;
        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
        break;
      }
      case 4: {
        functor::Add4Functor<Device, T> functor4;
        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3));
        break;
      }
      case 5: {
        functor::Add5Functor<Device, T> functor5;
        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4));
        break;
      }
      case 6: {
        functor::Add6Functor<Device, T> functor6;
        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5));
        break;
      }
      case 7: {
        functor::Add7Functor<Device, T> functor7;
        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6));
        break;
      }
      case 0: {
        functor::Add8Functor<Device, T> functor8;
        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7));
        r = 8;
        break;
      }
      case 1: {
        functor::Add9Functor<Device, T> functor9;
        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7), I(8));
        r = 9;
        break;
      }
    }

    for (; r < num; r += kWidth) {
      functor::Add8pFunctor<Device, T> functor8p;
      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
    }
#endif  // defined(__ANDROID_TYPES_SLIM__)

#undef I
  }
};
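The switch above may be easier to follow stripped of the functor machinery. Below is a minimal standalone sketch of the same blocking scheme over plain doubles (SumBlocked is an illustrative name, not TensorFlow code): the first group has num % 8 terms, widened to 8 or 9 when the remainder is 0 or 1 (a single leftover term has no one-input functor to handle it), and each later group folds exactly eight more inputs into the running total, mirroring Add8pFunctor.

#include <numeric>
#include <vector>

// Minimal sketch of the kernel's blocking scheme over plain doubles.
double SumBlocked(const std::vector<double>& x) {
  const int num = static_cast<int>(x.size());
  if (num == 0) return 0.0;
  int r = num % 8;
  if (r == 0) r = 8;
  if (r == 1) r = 9;     // a lone leftover has no 1-ary functor; take 9
  if (r > num) r = num;  // tiny inputs fit in a single group
  // First group of r terms (the switch's Add2..Add9 analogue).
  double total = std::accumulate(x.begin(), x.begin() + r, 0.0);
  // Then exactly 8 more terms per pass (the Add8p analogue).
  for (; r < num; r += 8) {
    total = std::accumulate(x.begin() + r, x.begin() + r + 8, total);
  }
  return total;
}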

template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    for (int i = 0; i < num; ++i) {
      // Step 1: ensure unary variants.
      OP_REQUIRES(
          ctx, ctx->input(i).dims() == 0,
          errors::InvalidArgument(
              "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
              "supported; inputs[",
              i, "] has shape: ", ctx->input(i).shape().DebugString(), "."));
    }

    // Step 2: Sum input variants in a tree-like structure using
    // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...).
    // For the output, create a default-constructed variant object.
    //
    // Pairwise summation provides better numerical precision by
    // reducing round-off error:
    //
    //   https://en.wikipedia.org/wiki/Pairwise_summation
    //
    // These two vectors are used to store and mark intermediate sums.
    gtl::InlinedVector<bool, 4> temp_filled(num, false);
    gtl::InlinedVector<Variant, 4> temp(num);

    // Tree-based summation.
    int skip = 1;
    int n = num;
    while (skip < n) {
      int i = skip;
      while (i < n) {
        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
        // inner loop if the variants are "large".

        // x[i - skip] += x[i]
        OP_REQUIRES_OK(ctx,
                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
        // We won't use this index again, recover its memory.
        temp[i].clear();
        i += 2 * skip;
      }
      if (i == n) {
        // x[0] += x[i - skip]
        OP_REQUIRES_OK(ctx,
                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
        // We won't use this index again, recover its memory.
        temp[i - skip].clear();
        n -= skip;
      }
      skip *= 2;
    }

    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
    out.scalar<Variant>()() = std::move(temp[0]);
    ctx->set_output(0, out);
  }

 private:
  // AddVariantTo efficiently performs:
  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
  // where array(ix) := (temp_filled[ix]
  //                     ? temp[ix]
  //                     : ctx->input(ix).scalar<Variant>()())
  // This reduces (possibly expensive) copying of Variants from
  // the inputs into temp at the lowest levels of the summation tree.
  static inline Status AddVariantTo(OpKernelContext* ctx, const int lhs_ix,
                                    const int rhs_ix,
                                    gtl::InlinedVector<Variant, 4>* temp,
                                    gtl::InlinedVector<bool, 4>* temp_filled) {
    Variant tmp;
    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
    const Variant& a = temp_filled->at(lhs_ix)
                           ? tmp
                           : ctx->input(lhs_ix).template scalar<Variant>()();
    const Variant& b = temp_filled->at(rhs_ix)
                           ? temp->at(rhs_ix)
                           : ctx->input(rhs_ix).template scalar<Variant>()();
    Variant* c = &temp->at(lhs_ix);
    TF_RETURN_IF_ERROR(
        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
    temp_filled->at(lhs_ix) = true;
    return Status::OK();
  }
};
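The index pattern of this tree reduction is easier to see on plain numbers; below is a minimal sketch with the same loop structure (illustrative code, not part of TensorFlow). For floating-point data, pairwise summation keeps worst-case round-off growth near O(log n) instead of the O(n) of left-to-right accumulation, which is the precision argument the comment above cites. AddVariantTo plays the role of += here, with the extra twist that an input Variant is only copied into temp the first time its slot is written, which is what temp_filled tracks.

#include <vector>

// Sketch of the kernel's pairwise (tree) summation index pattern:
// the stride doubles each round, x[i - skip] absorbs x[i], and a
// leftover partial sum is folded into x[0] before the next round.
double SumPairwise(std::vector<double> x) {
  if (x.empty()) return 0.0;
  int n = static_cast<int>(x.size());
  int skip = 1;
  while (skip < n) {
    int i = skip;
    while (i < n) {
      x[i - skip] += x[i];  // combine the pair (i - skip, i)
      i += 2 * skip;
    }
    if (i == n) {
      x[0] += x[i - skip];  // fold the odd element out into x[0]
      n -= skip;
    }
    skip *= 2;
  }
  return x[0];
}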

#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
-      AddNOp<dev##Device, type, OpKernel, OpKernelConstruction,   \
-             OpKernelContext>)
+      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
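To make the effect concrete, a registration such as REGISTER_ADDN_CPU(float) now expands, roughly (eliding REGISTER_KERNEL_BUILDER's own internals), to the plain two-parameter instantiation:

// Approximate expansion of REGISTER_ADDN_CPU(float) after this change.
// CPUDevice is the Eigen::ThreadPoolDevice typedef from this file.
REGISTER_KERNEL_BUILDER(
    Name("AddN").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    AddNOp<CPUDevice, float>);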
@@ -54,17 +271,15 @@ TF_CALL_COMPLEX_TYPES(REGISTER_ADDN_GPU);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(
-    Name("AddN")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<int32>("T")
-        .HostMemory("inputs")
-        .HostMemory("sum"),
-    AddNOp<CPUDevice, int32, OpKernel, OpKernelConstruction, OpKernelContext>);
+REGISTER_KERNEL_BUILDER(Name("AddN")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .HostMemory("inputs")
+                            .HostMemory("sum"),
+                        AddNOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_ADDN

}  // namespace tensorflow
tensorflow/core/kernels/aggregate_ops.h
@@ -18,11 +18,8 @@ limitations under the License.

#include <numeric>

#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {
@@ -223,226 +220,7 @@ struct Add9EigenImpl {
    out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
  }
};

}  // namespace functor

template <typename Device, typename T, class OpKernelT,
          class OpKernelConstructionT, class OpKernelContextT>
class AddNOp : public OpKernelT {
 public:
  explicit AddNOp(OpKernelConstructionT* context) : OpKernelT(context) {}

  void Compute(OpKernelContextT* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    // Try to forward and accumulate the result in one of the input buffers.
    int reused_input = -1;
    gtl::InlinedVector<int, 8> input_indices(num);
    std::iota(input_indices.begin(), input_indices.end(), 0);
    Tensor* output = nullptr;
    for (int input_idx = 0; input_idx < num; ++input_idx) {
      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
                                                  &output)) {
        reused_input = input_idx;
        break;
      }
    }
    if (reused_input == -1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
    } else if (reused_input > 0) {
      // Move the forwarded buffer to the front so we don't double count
      // anything if there are more than 8 inputs.
      input_indices[0] = reused_input;
      input_indices[reused_input] = 0;
    }
    auto To = output->flat<T>();

#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()

#if defined(__ANDROID_TYPES_SLIM__)
    // On Android by default, we only support additions of two arguments,
    // so we can reduce the number of template instantiations.
    OP_REQUIRES(ctx, num == 2,
                errors::InvalidArgument("Only additions of two arguments "
                                        "supported. Num inputs: ",
                                        num));
    functor::Add2Functor<Device, T> functor2;
    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
#else
    static const int kWidth = 8;
    int r = num % kWidth;

    switch (r) {
      case 2: {
        functor::Add2Functor<Device, T> functor2;
        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
        break;
      }
      case 3: {
        functor::Add3Functor<Device, T> functor3;
        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
        break;
      }
      case 4: {
        functor::Add4Functor<Device, T> functor4;
        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3));
        break;
      }
      case 5: {
        functor::Add5Functor<Device, T> functor5;
        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4));
        break;
      }
      case 6: {
        functor::Add6Functor<Device, T> functor6;
        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5));
        break;
      }
      case 7: {
        functor::Add7Functor<Device, T> functor7;
        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6));
        break;
      }
      case 0: {
        functor::Add8Functor<Device, T> functor8;
        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7));
        r = 8;
        break;
      }
      case 1: {
        functor::Add9Functor<Device, T> functor9;
        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7), I(8));
        r = 9;
        break;
      }
    }

    for (; r < num; r += kWidth) {
      functor::Add8pFunctor<Device, T> functor8p;
      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
    }
#endif  // defined(__ANDROID_TYPES_SLIM__)

#undef I
  }
};

template <typename Device, class OpKernelT, class OpKernelConstructionT,
          class OpKernelContextT>
class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
             OpKernelContextT> : public OpKernelT {
 public:
  explicit AddNOp(OpKernelConstructionT* context) : OpKernelT(context) {}

  void Compute(OpKernelContextT* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    for (int i = 0; i < num; ++i) {
      // Step 1: ensure unary variants.
      OP_REQUIRES(
          ctx, ctx->input(i).dims() == 0,
          errors::InvalidArgument(
              "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
              "supported; inputs[",
              i, "] has shape: ", ctx->input(i).shape().DebugString(), "."));
    }

    // Step 2: Sum input variants in a tree-like structure using
    // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...).
    // For the output, create a default-constructed variant object.
    //
    // Pairwise summation provides better numerical precision by
    // reducing round-off error:
    //
    //   https://en.wikipedia.org/wiki/Pairwise_summation
    //
    // These two vectors are used to store and mark intermediate sums.
    gtl::InlinedVector<bool, 4> temp_filled(num, false);
    gtl::InlinedVector<Variant, 4> temp(num);

    // Tree-based summation.
    int skip = 1;
    int n = num;
    while (skip < n) {
      int i = skip;
      while (i < n) {
        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
        // inner loop if the variants are "large".

        // x[i - skip] += x[i]
        OP_REQUIRES_OK(ctx,
                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
        // We won't use this index again, recover its memory.
        temp[i].clear();
        i += 2 * skip;
      }
      if (i == n) {
        // x[0] += x[i - skip]
        OP_REQUIRES_OK(ctx,
                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
        // We won't use this index again, recover its memory.
        temp[i - skip].clear();
        n -= skip;
      }
      skip *= 2;
    }

    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
    out.scalar<Variant>()() = std::move(temp[0]);
    ctx->set_output(0, out);
  }

 private:
  // AddVariantTo efficiently performs:
  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
  // where array(ix) := (temp_filled[ix]
  //                     ? temp[ix]
  //                     : ctx->input(ix).scalar<Variant>()())
  // This reduces (possibly expensive) copying of Variants from
  // the inputs into temp at the lowest levels of the summation tree.
  static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
                                    const int rhs_ix,
                                    gtl::InlinedVector<Variant, 4>* temp,
                                    gtl::InlinedVector<bool, 4>* temp_filled) {
    Variant tmp;
    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
    const Variant& a = temp_filled->at(lhs_ix)
                           ? tmp
                           : ctx->input(lhs_ix).template scalar<Variant>()();
    const Variant& b = temp_filled->at(rhs_ix)
                           ? temp->at(rhs_ix)
                           : ctx->input(rhs_ix).template scalar<Variant>()();
    Variant* c = &temp->at(lhs_ix);
    TF_RETURN_IF_ERROR(
        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
    temp_filled->at(lhs_ix) = true;
    return Status::OK();
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_