Changes to scatter_nd ops
* Rewrite CPU impl to be single-threaded and use vectorization; avoids race conditions. Removes use of the generator.
* Remove scatter_nd_mul and scatter_nd_div to reduce binary size until we figure out a better way to reduce the templating pain.
* Modify scatter_nd to add for repeated indices as opposed to update (this is the appropriate gradient for gather_nd, for example).
* Clean up docstrings.

Change: 138452341
This commit is contained in:
parent
aac685b720
commit
fd05b5ebc5
@ -146,43 +146,48 @@ class ScatterNdOp : public OpKernel {
|
||||
&num_updates, &slice_size);
|
||||
if (!c->status().ok()) return;
|
||||
|
||||
Tensor scratch;
|
||||
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||
|
||||
auto scratch_scalar = scratch.scalar<Index>();
|
||||
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));
|
||||
functor::SetZeroFunctor<Device, T> fill;
|
||||
fill(c->eigen_device<Device>(), out->flat<T>());
|
||||
auto output_matrix = out->template shaped<T, 2>(
|
||||
{shape.num_elements() / slice_size, slice_size});
|
||||
|
||||
Index bad_i = -1;
|
||||
switch (indices_nd) {
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
Tensor* out = nullptr; \
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
|
||||
functor::SetZeroFunctor<Device, T> fill; \
|
||||
fill(c->eigen_device<Device>(), out->flat<T>()); \
|
||||
if (shape.num_elements() > 0) { \
|
||||
auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||
functor::ScatterNdFunctor<Device, T, Index, \
|
||||
scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
|
||||
functor; \
|
||||
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||
output_flat, indices_flat, updates_flat, output_flat); \
|
||||
} \
|
||||
|
||||
if (shape.num_elements() > 0) {
|
||||
switch (indices_nd) {
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
|
||||
for (int i = 0; i < IXDIM; ++i) { \
|
||||
output_shape_prefix[i] = shape.dim_size(i); \
|
||||
} \
|
||||
functor::ScatterNdFunctor<Device, T, Index, scatter_nd_op::UpdateOp::ADD, \
|
||||
IXDIM> \
|
||||
functor; \
|
||||
bad_i = \
|
||||
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
|
||||
output_matrix, indices_flat, updates_flat, output_matrix); \
|
||||
} break
|
||||
PARAMS_CASE(0);
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
PARAMS_CASE(4);
|
||||
PARAMS_CASE(5);
|
||||
// TODO(simister): Re-enable this once binary size is under control.
|
||||
// PARAMS_CASE(0);
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
PARAMS_CASE(4);
|
||||
PARAMS_CASE(5);
|
||||
#undef PARAMS_CASE
|
||||
default:
|
||||
OP_REQUIRES(c, false,
|
||||
errors::InvalidArgument(
|
||||
"Only indices.shape[-1] values between 0 and 5 "
|
||||
"are currently supported. Requested rank: ",
|
||||
indices_nd));
|
||||
default:
|
||||
OP_REQUIRES(c, false,
|
||||
errors::InvalidArgument(
|
||||
"Only indices.shape[-1] values between 1 and 5 "
|
||||
"are currently supported. Requested rank: ",
|
||||
indices_nd));
|
||||
}
|
||||
}
|
||||
OP_REQUIRES(
|
||||
c, bad_i < 0,
|
||||
@ -236,24 +241,27 @@ class ScatterNdUpdateOp : public OpKernel {
|
||||
&indices_nd, &num_updates, &slice_size);
|
||||
if (!c->status().ok()) return;
|
||||
|
||||
Tensor scratch;
|
||||
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
|
||||
|
||||
auto scratch_scalar = scratch.scalar<Index>();
|
||||
auto indices_flat = indices.flat_inner_dims<Index>();
|
||||
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
|
||||
|
||||
auto params_matrix = params.template shaped<T, 2>(
|
||||
{params_shape.num_elements() / slice_size, slice_size});
|
||||
Index bad_i = -1;
|
||||
c->forward_ref_input_to_ref_output(0, 0);
|
||||
|
||||
switch (indices_nd) {
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \
|
||||
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
|
||||
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
|
||||
params_flat, indices_flat, updates_flat, params_flat); \
|
||||
#define PARAMS_CASE(IXDIM) \
|
||||
case IXDIM: { \
|
||||
typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
|
||||
for (int i = 0; i < IXDIM; ++i) { \
|
||||
output_shape_prefix[i] = params_shape.dim_size(i); \
|
||||
} \
|
||||
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
|
||||
bad_i = \
|
||||
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
|
||||
params_matrix, indices_flat, updates_flat, params_matrix); \
|
||||
} break
|
||||
PARAMS_CASE(0);
|
||||
// TODO(simister): Re-enable this once binary size is under control.
|
||||
// PARAMS_CASE(0);
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
@ -306,11 +314,13 @@ class ScatterNdUpdateOp : public OpKernel {
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
|
||||
scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
|
||||
scatter_nd_op::UpdateOp::SUB); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
|
||||
scatter_nd_op::UpdateOp::MUL); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
|
||||
scatter_nd_op::UpdateOp::DIV);
|
||||
scatter_nd_op::UpdateOp::SUB);
|
||||
// TODO(simister): Find a way to reduce amount of templated generated code
|
||||
// to reduce build size, then re-enable these additional operations.
|
||||
// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
|
||||
// scatter_nd_op::UpdateOp::MUL); \
|
||||
// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
|
||||
// scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
#define REGISTER_SCATTER_ND(type, dev) \
|
||||
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
|
||||
@ -329,8 +339,9 @@ class ScatterNdUpdateOp : public OpKernel {
|
||||
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
|
||||
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||
// TODO(simister): Re-enable all types after binary size is under control.
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||
|
||||
// Registers GPU kernels.
|
||||
#if GOOGLE_CUDA
|
||||
@ -356,47 +367,4 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||
#undef REGISTER_SCATTER_ND_KERNEL
|
||||
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
|
||||
#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \
|
||||
template <> \
|
||||
Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \
|
||||
OpKernelContext* c, const GPUDevice& d, \
|
||||
typename TTypes<T, IXDIM>::Tensor params, \
|
||||
typename TTypes<Index, 2>::ConstTensor indices, \
|
||||
typename TTypes<T, 2>::ConstTensor updates); \
|
||||
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>;
|
||||
|
||||
#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 0); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 1); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 2); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 3); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 4); \
|
||||
DECLARE_GPU_SPECS_OP(T, Index, op, 5)
|
||||
|
||||
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \
|
||||
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
// TODO(simister): Re-enable when GPU support is working.
|
||||
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
#undef DECLARE_GPU_SPECS_OPS
|
||||
#undef DECLARE_GPU_SPECS_OP
|
||||
|
||||
} // namespace functor
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -48,12 +48,13 @@ template <typename Device, typename T, typename Index,
|
||||
scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
struct ScatterNdFunctor {
|
||||
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
|
||||
Index operator()(const Device& d, const Index slice_size,
|
||||
typename TTypes<Index>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput);
|
||||
Index operator()(
|
||||
const Device& d, const Index slice_size,
|
||||
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
|
||||
typename TTypes<T, 2>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, 2>::Tensor Toutput);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
@ -42,147 +42,113 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
class OpKernelContext;
|
||||
|
||||
// Specialization of UpdateExecutor to CPU
|
||||
namespace generator {
|
||||
namespace update_executor {
|
||||
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op>
|
||||
template <typename Input, typename Update, typename Output,
|
||||
scatter_nd_op::UpdateOp OP>
|
||||
class UpdateExecutor {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size);
|
||||
EIGEN_STRONG_INLINE static void Execute(Input value, Update update,
|
||||
Output output);
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ASSIGN> {
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> {
|
||||
public:
|
||||
static void Update(T* /* unused */, const T* updates, T* output,
|
||||
Index slice_size) {
|
||||
std::copy_n(updates, slice_size, output);
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output = update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::ADD> {
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output, std::plus<T>());
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output += update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::SUB> {
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output, std::minus<T>());
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output -= update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::MUL> {
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::MUL> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output,
|
||||
std::multiplies<T>());
|
||||
EIGEN_STRONG_INLINE static void Execute(Input input, Update update,
|
||||
Output output) {
|
||||
output = input * update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index>
|
||||
class UpdateExecutor<T, Index, scatter_nd_op::UpdateOp::DIV> {
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::DIV> {
|
||||
public:
|
||||
static void Update(T* input, const T* updates, T* output, Index slice_size) {
|
||||
std::transform(input, input + slice_size, updates, output,
|
||||
std::divides<T>());
|
||||
EIGEN_STRONG_INLINE static void Execute(Input input, Update update,
|
||||
Output output) {
|
||||
output = input / update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
class ScatterNdSliceGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator(
|
||||
const Index slice_size, typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput,
|
||||
std::atomic<Index>* error_loc)
|
||||
: slice_size_(slice_size),
|
||||
Tparams_(Tparams),
|
||||
Tindices_(Tindices),
|
||||
Tupdates_(Tupdates),
|
||||
Toutput_(Toutput),
|
||||
error_loc_(error_loc) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC bool GenerateIndices(
|
||||
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
|
||||
(*ix)[IXDIM] = 0;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < IXDIM; ++i) {
|
||||
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
|
||||
(*ix)[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
return out_of_bounds;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
|
||||
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||
auto loc = loc_array[0];
|
||||
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix_params;
|
||||
Eigen::array<Eigen::DenseIndex, 2> ix_updates;
|
||||
ix_updates[0] = loc;
|
||||
ix_updates[1] = 0;
|
||||
const bool out_of_bounds = GenerateIndices(loc, &ix_params);
|
||||
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||
error_loc_->store(loc);
|
||||
} else {
|
||||
UpdateExecutor<T, Index, op>::Update(&Tparams_(ix_params),
|
||||
&Tupdates_(ix_updates),
|
||||
&Toutput_(ix_params), slice_size_);
|
||||
}
|
||||
return static_cast<int32>(0); // Return something...
|
||||
}
|
||||
|
||||
protected:
|
||||
const Index slice_size_;
|
||||
mutable typename TTypes<T, IXDIM + 1>::Tensor Tparams_;
|
||||
const typename TTypes<Index, 2>::ConstTensor Tindices_;
|
||||
const typename TTypes<T, 2>::ConstTensor Tupdates_;
|
||||
mutable typename TTypes<T, IXDIM + 1>::Tensor Toutput_;
|
||||
std::atomic<Index>* error_loc_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
} // namespace update_executor
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Implementation of update functor for CPU.
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
|
||||
struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
|
||||
Index operator()(const CPUDevice& d, const Index slice_size,
|
||||
typename TTypes<Index>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, IXDIM + 1>::Tensor Toutput) {
|
||||
std::atomic<Index> error_loc(-1);
|
||||
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
|
||||
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
|
||||
Index operator()(
|
||||
const CPUDevice& d, const Index slice_size,
|
||||
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
|
||||
typename TTypes<T, 2>::Tensor Tparams,
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices,
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates,
|
||||
typename TTypes<T, 2>::Tensor Toutput) {
|
||||
// error_loc is -1 if there's no out-of-bounds index,
|
||||
// otherwise it is the location of an OOB index in Tindices.
|
||||
Index error_loc = -1;
|
||||
|
||||
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
|
||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
|
||||
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
|
||||
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
|
||||
broadcast_dims.set(0, batch_size);
|
||||
#endif
|
||||
|
||||
generator::ScatterNdSliceGenerator<T, Index, op, IXDIM> generator(
|
||||
slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc);
|
||||
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
|
||||
.broadcast(broadcast_dims)
|
||||
.generate(generator)
|
||||
.sum();
|
||||
Index batch_strides[IXDIM];
|
||||
for (int dim = IXDIM - 1; dim >= 0; --dim) {
|
||||
if (dim == IXDIM - 1) {
|
||||
batch_strides[dim] = 1;
|
||||
} else {
|
||||
batch_strides[dim] =
|
||||
batch_strides[dim + 1] * output_shape_prefix[dim + 1];
|
||||
}
|
||||
}
|
||||
|
||||
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||
// otherwise it returns the location of an OOB index in Tindices.
|
||||
return error_loc.load();
|
||||
for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
|
||||
Index i = 0;
|
||||
bool out_of_bounds = false;
|
||||
for (int dim = 0; dim < IXDIM; ++dim) {
|
||||
const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
|
||||
out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
|
||||
i += ix_d * batch_strides[dim];
|
||||
}
|
||||
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||
error_loc = loc;
|
||||
break;
|
||||
} else {
|
||||
auto input_chip = Toutput.template chip<0>(i);
|
||||
auto output_chip = input_chip.device(d);
|
||||
auto update_chip = Tupdates.template chip<0>(loc);
|
||||
update_executor::UpdateExecutor<
|
||||
decltype(input_chip), decltype(update_chip), decltype(output_chip),
|
||||
OP>::Execute(input_chip, update_chip, output_chip);
|
||||
}
|
||||
}
|
||||
|
||||
return error_loc;
|
||||
}
|
||||
};
|
||||
|
||||
@ -190,11 +156,12 @@ struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
|
||||
template Index \
|
||||
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
|
||||
const CPUDevice& d, const Index slice_size, \
|
||||
typename TTypes<Index>::Scalar Tscratch, \
|
||||
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Tparams, \
|
||||
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
|
||||
output_shape_prefix, \
|
||||
typename TTypes<T, 2>::Tensor Tparams, \
|
||||
typename TTypes<Index, 2>::ConstTensor Tindices, \
|
||||
typename TTypes<T, 2>::ConstTensor Tupdates, \
|
||||
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::Tensor Toutput)
|
||||
typename TTypes<T, 2>::Tensor Toutput)
|
||||
|
||||
#define REGISTER_SCATTER_ND_INDEX(type, op) \
|
||||
REGISTER_SCATTER_ND_FULL(type, int32, op); \
|
||||
@ -205,9 +172,11 @@ struct ScatterNdFunctor<CPUDevice, T, Index, op, IXDIM> {
|
||||
|
||||
#define REGISTER_SCATTER_ND_MATH(type) \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV);
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);
|
||||
// TODO(simister): Re-enable after identifying a way to reduce the binary size
|
||||
// due to too many template instantiations.
|
||||
// REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \
|
||||
// REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 0
|
||||
#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
||||
// TODO(simister): Re-enable once binary size is under control.
|
||||
// #define CPU_PROVIDED_IXDIM 0
|
||||
// #include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h"
|
||||
// #undef CPU_PROVIDED_IXDIM
|
||||
|
@ -48,31 +48,32 @@ class ScatterNdUpdateOpTest : public OpsTestBase {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
|
||||
MakeOp(DT_STRING_REF, DT_INT32);
|
||||
AddInputFromArray<string>(TensorShape({1}), {"Brain"});
|
||||
AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_STRING, TensorShape({1}));
|
||||
test::FillValues<string>(&expected, {"TensorFlow"});
|
||||
test::ExpectTensorEqual<string>(expected, params_tensor);
|
||||
}
|
||||
// TODO(simister): Re-enable this once binary size is under control.
|
||||
// TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
|
||||
// MakeOp(DT_STRING_REF, DT_INT32);
|
||||
// AddInputFromArray<string>(TensorShape({1}), {"Brain"});
|
||||
// AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
// AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
|
||||
// TF_ASSERT_OK(RunOpKernel());
|
||||
// // Check the new state of the input
|
||||
// Tensor params_tensor = *mutable_input(0).tensor;
|
||||
// Tensor expected(allocator(), DT_STRING, TensorShape({1}));
|
||||
// test::FillValues<string>(&expected, {"TensorFlow"});
|
||||
// test::ExpectTensorEqual<string>(expected, params_tensor);
|
||||
// }
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
|
||||
MakeOp(DT_BOOL_REF, DT_INT32);
|
||||
AddInputFromArray<bool>(TensorShape({1}), {false});
|
||||
AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
AddInputFromArray<bool>(TensorShape({1}), {true});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
// Check the new state of the input
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
|
||||
test::FillValues<bool>(&expected, {true});
|
||||
test::ExpectTensorEqual<bool>(expected, params_tensor);
|
||||
}
|
||||
// TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
|
||||
// MakeOp(DT_BOOL_REF, DT_INT32);
|
||||
// AddInputFromArray<bool>(TensorShape({1}), {false});
|
||||
// AddInputFromArray<int32>(TensorShape({1}), {0});
|
||||
// AddInputFromArray<bool>(TensorShape({1}), {true});
|
||||
// TF_ASSERT_OK(RunOpKernel());
|
||||
// // Check the new state of the input
|
||||
// Tensor params_tensor = *mutable_input(0).tensor;
|
||||
// Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
|
||||
// test::FillValues<bool>(&expected, {true});
|
||||
// test::ExpectTensorEqual<bool>(expected, params_tensor);
|
||||
// }
|
||||
|
||||
TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
@ -111,6 +112,7 @@ TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
|
||||
10002, 0, 0, 0, 777, 778, 779});
|
||||
test::ExpectTensorEqual<float>(expected, params_tensor);
|
||||
}
|
||||
|
||||
/*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
|
@ -4395,12 +4395,16 @@ REGISTER_OP("ScatterNd")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Doc(
|
||||
R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices.
|
||||
This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor.
|
||||
R"doc(Creates a new tensor by applying sparse `updates` to individual
|
||||
values or slices within a zero tensor of the given `shape` tensor according to
|
||||
indices. This operator is the inverse of the [tf.gather_nd](#gather_nd)
|
||||
operator which extracts values or slices from a given tensor.
|
||||
|
||||
TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax.
|
||||
TODO(simister): Add a link to Variable.__getitem__ documentation on slice
|
||||
syntax.
|
||||
|
||||
`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank
|
||||
`Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `shape`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
@ -4415,7 +4419,9 @@ dimension of `shape`.
|
||||
[d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]].
|
||||
```
|
||||
|
||||
The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
|
||||
The simplest form of scatter is to insert individual elements in a tensor by
|
||||
index. For example, say we want to insert 4 scattered elements in a rank-1
|
||||
tensor with 8 elements.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/ScatterNd1.png" alt>
|
||||
@ -4434,7 +4440,9 @@ The resulting tensor would look like this:
|
||||
|
||||
[0, 11, 0, 10, 9, 0, 0, 12]
|
||||
|
||||
We can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
|
||||
We can also, insert entire slices of a higher rank tensor all at once. For
|
||||
example, if we wanted to insert two slices in the first dimension of a
|
||||
rank-3 tensor with two matrices of new values.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/ScatterNd2.png" alt>
|
||||
@ -4459,10 +4467,14 @@ The resulting tensor would look like this:
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as tensor. A tensor of updated values
|
||||
to store in ref.
|
||||
shape: A vector. The shape of the resulting tensor.
|
||||
output: A new tensor with the given shape and updates applied according to the indices.)doc");
|
||||
output: A new tensor with the given shape and updates applied according
|
||||
to the indices.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
||||
.Attr("min: float = -6.0")
|
||||
|
@ -24864,126 +24864,6 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdDiv"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdMul"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdSub"
|
||||
input_arg {
|
||||
|
@ -15579,140 +15579,6 @@ op {
|
||||
summary: "Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n add = tf.scatter_nd_add(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(add)\n\nThe resulting update to ref would look like this:\n\n [1, 13, 3, 14, 14, 6, 7, 20]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdDiv"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:\n\n ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([2, 3, 4, 5])\n sub = tf.scatter_nd_div(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [10, 5, 30, 13, 25, 60, 70, 16]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdMul"
|
||||
input_arg {
|
||||
name: "ref"
|
||||
description: "A mutable Tensor. Should be from a Variable node."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "updates"
|
||||
description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output_ref"
|
||||
description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done."
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention."
|
||||
}
|
||||
summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`."
|
||||
description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_mul(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, 22, 3, 40, 45, 6, 7, 96]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices."
|
||||
}
|
||||
op {
|
||||
name: "ScatterNdSub"
|
||||
input_arg {
|
||||
|
@ -453,8 +453,9 @@ REGISTER_OP("ScatterNdUpdate")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Doc(
|
||||
R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`.
|
||||
.Doc(R"doc(
|
||||
Applies sparse `updates` to individual values or slices within a given
|
||||
variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
@ -471,7 +472,8 @@ dimension of `ref`.
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this:
|
||||
For example, say we want to update 4 scattered elements in a rank-1 tensor with
|
||||
8 elements. In Python, that update would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1] ,[7]])
|
||||
@ -484,13 +486,20 @@ The resulting update to ref would look like this:
|
||||
|
||||
[1, 11, 3, 10, 9, 6, 7, 12]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
indices: A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated
|
||||
values to store in ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to
|
||||
use the updated values after the update is done.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdAdd")
|
||||
.Input("ref: Ref(T)")
|
||||
@ -500,8 +509,9 @@ REGISTER_OP("ScatterNdAdd")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
.Doc(R"doc(
|
||||
Applies sparse addition between `updates` and individual values or slices
|
||||
within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
@ -518,7 +528,8 @@ dimension of `ref`.
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:
|
||||
For example, say we want to add 4 scattered elements to a rank-1 tensor with 8
|
||||
elements. In Python, that addition would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
@ -531,13 +542,20 @@ The resulting update to ref would look like this:
|
||||
|
||||
[1, 13, 3, 14, 14, 6, 7, 20]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
indices: A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
to add to ref.
|
||||
use_locking: An optional bool. Defaults to False. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want
|
||||
to use the updated values after the update is done.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdSub")
|
||||
.Input("ref: Ref(T)")
|
||||
@ -547,8 +565,9 @@ REGISTER_OP("ScatterNdSub")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
.Doc(R"doc(
|
||||
Applies sparse subtraction between `updates` and individual values or slices
|
||||
within a given variable according to `indices`.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
@ -565,7 +584,8 @@ dimension of `ref`.
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this:
|
||||
For example, say we want to subtract 4 scattered elements from a rank-1 tensor
|
||||
with 8 elements. In Python, that subtraction would look like this:
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
@ -578,107 +598,133 @@ The resulting update to ref would look like this:
|
||||
|
||||
[1, -9, 3, -6, -4, 6, 7, -4]
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
indices: A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to False. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want
|
||||
to use the updated values after the update is done.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ScatterNdMul")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
// TODO(simister): Re-enable once these additional ops do not dramatically
|
||||
// increase binary size.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
// REGISTER_OP("ScatterNdMul")
|
||||
// .Input("ref: Ref(T)")
|
||||
// .Input("indices: Tindices")
|
||||
// .Input("updates: T")
|
||||
// .Output("output_ref: Ref(T)")
|
||||
// .Attr("T: numbertype")
|
||||
// .Attr("Tindices: {int32, int64}")
|
||||
// .Attr("use_locking: bool = false")
|
||||
// .Doc(
|
||||
// R"doc(Applies sparse multiplication between `updates` and individual
|
||||
// values or slices within a given variable according to `indices`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
// `indices` must be integer tensor, containing indices into `ref`.
|
||||
// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
// The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
// dimension of `ref`.
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:
|
||||
// ```
|
||||
// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
// ```
|
||||
|
||||
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([9, 10, 11, 12])
|
||||
sub = tf.scatter_nd_mul(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(sub)
|
||||
// For example, say we want to multiply 4 scattered elements with a rank-1
|
||||
// tensor with 8 elements. In Python, that multiplication would look like this:
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
// indices = tf.constant([[4], [3], [1], [7]])
|
||||
// updates = tf.constant([9, 10, 11, 12])
|
||||
// sub = tf.scatter_nd_mul(ref, indices, updates)
|
||||
// with tf.Session() as sess:
|
||||
// print sess.run(sub)
|
||||
|
||||
[1, 22, 3, 40, 45, 6, 7, 96]
|
||||
// The resulting update to ref would look like this:
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
// [1, 22, 3, 40, 45, 6, 7, 96]
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
// See [tf.scatter_nd](#scatter_nd) for more details about how to make updates
|
||||
// to slices.
|
||||
|
||||
REGISTER_OP("ScatterNdDiv")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Doc(
|
||||
R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`.
|
||||
// ref: A mutable Tensor. Should be from a Variable node.
|
||||
// indices: A Tensor. Must be one of the following types: int32, int64. A tensor
|
||||
// of indices into ref.
|
||||
// updates: A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
// to multiply with ref.
|
||||
// use_locking: An optional bool. Defaults to False. If True, the assignment will
|
||||
// be protected by a lock; otherwise the behavior is undefined, but may exhibit
|
||||
// less contention.
|
||||
// output_ref: Same as ref. Returned as a convenience for operations that want
|
||||
// to use the updated values after the update is done.)doc");
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
// REGISTER_OP("ScatterNdDiv")
|
||||
// .Input("ref: Ref(T)")
|
||||
// .Input("indices: Tindices")
|
||||
// .Input("updates: T")
|
||||
// .Output("output_ref: Ref(T)")
|
||||
// .Attr("T: numbertype")
|
||||
// .Attr("Tindices: {int32, int64}")
|
||||
// .Attr("use_locking: bool = false")
|
||||
// .Doc(
|
||||
// R"doc(Applies sparse division between `updates` and individual
|
||||
// values or slices within a given variable according to `indices`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
// `indices` must be integer tensor, containing indices into `ref`.
|
||||
// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
// The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
// dimension of `ref`.
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:
|
||||
// ```
|
||||
// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
// ```
|
||||
|
||||
ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
|
||||
indices = tf.constant([[4], [3], [1], [7]])
|
||||
updates = tf.constant([2, 3, 4, 5])
|
||||
sub = tf.scatter_nd_div(ref, indices, updates)
|
||||
with tf.Session() as sess:
|
||||
print sess.run(sub)
|
||||
// For example, say we want to divide a rank-1 tensor with 8 elements by 4
|
||||
// scattered elements. In Python, that division would look like this:
|
||||
|
||||
The resulting update to ref would look like this:
|
||||
// ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])
|
||||
// indices = tf.constant([[4], [3], [1], [7]])
|
||||
// updates = tf.constant([2, 3, 4, 5])
|
||||
// sub = tf.scatter_nd_div(ref, indices, updates)
|
||||
// with tf.Session() as sess:
|
||||
// print sess.run(sub)
|
||||
|
||||
[10, 5, 30, 13, 25, 60, 70, 16]
|
||||
// The resulting update to ref would look like this:
|
||||
|
||||
See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices.
|
||||
// [10, 5, 30, 13, 25, 60, 70, 16]
|
||||
|
||||
ref: A mutable Tensor. Should be from a Variable node.
|
||||
indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref.
|
||||
updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref.
|
||||
use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc");
|
||||
// See [tf.scatter_nd](#scatter_nd) for more details about how to make updates
|
||||
// to slices.
|
||||
|
||||
// ref: A mutable Tensor. Should be from a Variable node.
|
||||
// indices: A Tensor. Must be one of the following types: int32, int64. A tensor
|
||||
// of indices into ref.
|
||||
// updates: A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
// to divide ref by.
|
||||
// use_locking: An optional bool. Defaults to False. If True, the assignment will
|
||||
// be protected by a lock; otherwise the behavior is undefined, but may exhibit
|
||||
// less contention.
|
||||
// output_ref: Same as ref. Returned as a convenience for operations that want
|
||||
// to use the updated values after the update is done.)doc");
|
||||
|
||||
REGISTER_OP("CountUpTo")
|
||||
.Input("ref: Ref(T)")
|
||||
|
@ -78,7 +78,7 @@ def _NumpyDiv(ref, indices, updates):
|
||||
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
|
||||
|
||||
|
||||
class ScatterTest(tf.test.TestCase):
|
||||
class ScatterNdTest(tf.test.TestCase):
|
||||
|
||||
def _VariableRankTest(self,
|
||||
np_scatter,
|
||||
@ -145,11 +145,13 @@ class ScatterTest(tf.test.TestCase):
|
||||
def testVariableRankSub(self):
|
||||
self._VariableRankTests(_NumpySub, tf.scatter_nd_sub)
|
||||
|
||||
def testVariableRankMul(self):
|
||||
self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul)
|
||||
# TODO(simister): Re-enable once binary size increase due to
|
||||
# scatter_nd ops is under control.
|
||||
# def testVariableRankMul(self):
|
||||
# self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul)
|
||||
|
||||
def testVariableRankDiv(self):
|
||||
self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div)
|
||||
# def testVariableRankDiv(self):
|
||||
# self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div)
|
||||
|
||||
def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
|
||||
for vtype in (np.float32, np.float64):
|
||||
@ -167,25 +169,29 @@ class ScatterTest(tf.test.TestCase):
|
||||
"""This tests scatter_add using indices that repeat."""
|
||||
self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_nd_add)
|
||||
self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_nd_sub)
|
||||
self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul)
|
||||
self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div)
|
||||
# TODO(simister): Re-enable once binary size increase due to
|
||||
# extra templating is back under control.
|
||||
# self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul)
|
||||
# self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div)
|
||||
|
||||
def testBooleanScatterUpdate(self):
|
||||
with self.test_session(use_gpu=False) as session:
|
||||
var = tf.Variable([True, False])
|
||||
update0 = tf.scatter_nd_update(var, [[1]], [True])
|
||||
update1 = tf.scatter_nd_update(
|
||||
var, tf.constant(
|
||||
[[0]], dtype=tf.int64), [False])
|
||||
var.initializer.run()
|
||||
|
||||
session.run([update0, update1])
|
||||
|
||||
self.assertAllEqual([False, True], var.eval())
|
||||
# TODO(simister): Re-enable once binary size increase due to
|
||||
# extra templating is back under control and this op is re-enabled
|
||||
# def testBooleanScatterUpdate(self):
|
||||
# with self.test_session(use_gpu=False) as session:
|
||||
# var = tf.Variable([True, False])
|
||||
# update0 = tf.scatter_nd_update(var, [[1]], [True])
|
||||
# update1 = tf.scatter_nd_update(
|
||||
# var, tf.constant(
|
||||
# [[0]], dtype=tf.int64), [False])
|
||||
# var.initializer.run()
|
||||
# session.run([update0, update1])
|
||||
# self.assertAllEqual([False, True], var.eval())
|
||||
|
||||
def testScatterOutOfRangeCpu(self):
|
||||
for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
|
||||
tf.scatter_nd_div, tf.scatter_nd_update):
|
||||
# TODO(simister): Re-enable once binary size increase due to
|
||||
# scatter_nd ops is under control.
|
||||
# tf.scatter_nd_mul, tf.scatter_nd_div,
|
||||
for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update):
|
||||
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
||||
updates = np.array([-3, -4, -5]).astype(np.float32)
|
||||
with self.test_session(use_gpu=False):
|
||||
@ -355,8 +361,10 @@ class ScatterTest(tf.test.TestCase):
|
||||
def _disabledTestScatterOutOfRangeGpu(self):
|
||||
if not tf.test.IsBuiltWithCuda():
|
||||
return
|
||||
for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
|
||||
tf.scatter_nd_div, tf.scatter_nd_update):
|
||||
# TODO(simister): Re-enable once binary size increase due to
|
||||
# scatter_nd ops is under control.
|
||||
# tf.scatter_nd_mul, tf.scatter_nd_div,
|
||||
for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update):
|
||||
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
||||
updates = np.array([-3, -4, -5]).astype(np.float32)
|
||||
# With GPU, the code ignores indices that are out of range.
|
||||
@ -375,6 +383,14 @@ class ScatterTest(tf.test.TestCase):
|
||||
indices = np.array([2, 0, 6])
|
||||
op(ref, indices, updates).eval()
|
||||
|
||||
def testScatterNdRepatedIndicesAdd(self):
|
||||
indices = tf.zeros([100000, 1], tf.int32)
|
||||
values = np.random.randn(100000)
|
||||
shape = [1]
|
||||
with self.test_session():
|
||||
val = tf.scatter_nd(indices, values, shape).eval()
|
||||
self.assertAllClose([np.sum(values)], val)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -69,8 +69,10 @@ from tensorflow.python.ops.state_ops import scatter_sub
|
||||
from tensorflow.python.ops.state_ops import scatter_update
|
||||
from tensorflow.python.ops.state_ops import scatter_nd_add
|
||||
from tensorflow.python.ops.state_ops import scatter_nd_sub
|
||||
from tensorflow.python.ops.state_ops import scatter_nd_mul
|
||||
from tensorflow.python.ops.state_ops import scatter_nd_div
|
||||
# TODO(simister): Re-enable once binary size increase due to scatter_nd
|
||||
# ops is under control.
|
||||
# from tensorflow.python.ops.state_ops import scatter_nd_mul
|
||||
# from tensorflow.python.ops.state_ops import scatter_nd_div
|
||||
from tensorflow.python.ops.state_ops import scatter_nd_update
|
||||
from tensorflow.python.ops.string_ops import *
|
||||
from tensorflow.python.ops.template import *
|
||||
|
@ -98,8 +98,6 @@ automatically by the optimizers in most cases.
|
||||
@@scatter_nd_update
|
||||
@@scatter_nd_add
|
||||
@@scatter_nd_sub
|
||||
@@scatter_nd_mul
|
||||
@@scatter_nd_div
|
||||
@@sparse_mask
|
||||
@@IndexedSlices
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user