Add registrations for non-stateful scatter_nd_min and scatter_nd_max.

Manual import of #26923. Closes #26923. Fixes #20402

PiperOrigin-RevId: 314561995
Change-Id: Ibe1629d5ef769713d9058a6a72dc1421aa252ec4
A. Unique TensorFlower 2020-06-03 10:54:24 -07:00 committed by TensorFlower Gardener
parent 0836e5e4ea
commit 23384f2ef1
24 changed files with 684 additions and 27 deletions

View File

@@ -0,0 +1,31 @@
op {
graph_op_name: "ResourceScatterNdMax"
in_arg {
name: "ref"
description: <<END
A resource handle. Must be from a VarHandleOp.
END
}
in_arg {
name: "indices"
description: <<END
A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
END
}
in_arg {
name: "updates"
description: <<END
A Tensor. Must have the same type as ref. A tensor of
values whose element-wise max is taken with ref.
END
}
attr {
name: "use_locking"
description: <<END
An optional bool. Defaults to True. If True, the assignment will
be protected by a lock; otherwise the behavior is undefined,
but may exhibit less contention.
END
}
}

View File

@@ -0,0 +1,31 @@
op {
graph_op_name: "ResourceScatterNdMin"
in_arg {
name: "ref"
description: <<END
A resource handle. Must be from a VarHandleOp.
END
}
in_arg {
name: "indices"
description: <<END
A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
END
}
in_arg {
name: "updates"
description: <<END
A Tensor. Must have the same type as ref. A tensor of
values whose element-wise min is taken with ref.
END
}
attr {
name: "use_locking"
description: <<END
An optional bool. Defaults to True. If True, the assignment will
be protected by a lock; otherwise the behavior is undefined,
but may exhibit less contention.
END
}
}
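
Taken together, these two resource ops back the scatter_nd_max / scatter_nd_min variable methods added later in this change. A minimal usage sketch, assuming an eager TF 2.x build that includes these kernels (values are illustrative only):

import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0, 4.0])
# Element-wise max against the scattered updates: index 0 rises to 5.0,
# while max(3.0, 0.5) leaves index 2 unchanged.
v.scatter_nd_max(indices=[[0], [2]], updates=[5.0, 0.5])
print(v.numpy())  # expected: [5., 2., 3., 4.]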

View File

@@ -0,0 +1,40 @@
op {
graph_op_name: "ScatterNdMax"
visibility: HIDDEN
in_arg {
name: "ref"
description: <<END
A mutable Tensor. Should be from a Variable node.
END
}
in_arg {
name: "indices"
description: <<END
A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
END
}
in_arg {
name: "updates"
description: <<END
A Tensor. Must have the same type as ref. A tensor of
values whose element-wise max is taken with ref.
END
}
out_arg {
name: "output_ref"
description: <<END
Same as ref. Returned as a convenience for operations that want
to use the updated values after the update is done.
END
}
attr {
name: "use_locking"
description: <<END
An optional bool. Defaults to True. If True, the assignment will
be protected by a lock; otherwise the behavior is undefined,
but may exhibit less contention.
END
}
summary: "Computes element-wise maximum."
}

View File

@@ -0,0 +1,40 @@
op {
graph_op_name: "ScatterNdMin"
visibility: HIDDEN
in_arg {
name: "ref"
description: <<END
A mutable Tensor. Should be from a Variable node.
END
}
in_arg {
name: "indices"
description: <<END
A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
END
}
in_arg {
name: "updates"
description: <<END
A Tensor. Must have the same type as ref. A tensor of
values whose element-wise min is taken with ref.
END
}
out_arg {
name: "output_ref"
description: <<END
Same as ref. Returned as a convenience for operations that want
to use the updated values after the update is done.
END
}
attr {
name: "use_locking"
description: <<END
An optional bool. Defaults to True. If True, the assignment will
be protected by a lock; otherwise the behavior is undefined,
but may exhibit less contention.
END
}
summary: "Computes element-wise minimum."
}
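
These ref-variable ops surface as tf.compat.v1.scatter_nd_max / tf.compat.v1.scatter_nd_min (see the renames entry at the end of this change). A hedged graph-mode sketch, assuming a V1-style non-resource variable:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

ref = tf.Variable([1.0, 2.0, 3.0, 4.0], use_resource=False)
# min(2.0, 0.0) lowers index 1; min(4.0, 10.0) leaves index 3 unchanged.
op = tf.scatter_nd_min(ref, [[1], [3]], [0.0, 10.0])
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(op))  # expected: [1., 0., 3., 4.]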

View File

@@ -0,0 +1,27 @@
op {
graph_op_name: "TensorScatterMax"
in_arg {
name: "tensor"
description: <<END
Tensor to update.
END
}
in_arg {
name: "indices"
description: <<END
Index tensor.
END
}
in_arg {
name: "updates"
description: <<END
Updates to scatter into output.
END
}
out_arg {
name: "output"
description: <<END
A new tensor copied from tensor, whose values are the element-wise maximum of tensor and updates according to indices.
END
}
}

View File

@@ -0,0 +1,27 @@
op {
graph_op_name: "TensorScatterMin"
in_arg {
name: "tensor"
description: <<END
Tensor to update.
END
}
in_arg {
name: "indices"
description: <<END
Index tensor.
END
}
in_arg {
name: "updates"
description: <<END
Updates to scatter into output.
END
}
out_arg {
name: "output"
description: <<END
A new tensor copied from tensor, whose values are the element-wise minimum of tensor and updates according to indices.
END
}
}

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "ResourceScatterNdMax"
visibility: HIDDEN
}

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "ResourceScatterNdMin"
visibility: HIDDEN
}

View File

@@ -0,0 +1,6 @@
op {
graph_op_name: "TensorScatterMax"
endpoint {
name: "tensor_scatter_nd_max"
}
}

View File

@@ -0,0 +1,6 @@
op {
graph_op_name: "TensorScatterMin"
endpoint {
name: "tensor_scatter_nd_min"
}
}
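
With these endpoints in place, the ops are reachable as tf.tensor_scatter_nd_min and tf.tensor_scatter_nd_max. The values below mirror the testUpdateMinMax case added further down; this assumes a build with the new kernels:

import tensorflow as tf

t = tf.ones([8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([0.0, 2.0, -1.0, 1.2])
print(tf.tensor_scatter_nd_min(t, indices, updates).numpy())
# expected: [1. -1.  1.  1.  0.  1.  1.  1.]
print(tf.tensor_scatter_nd_max(t, indices, updates).numpy())
# expected: [1.  1.  1.  2.  1.  1.  1.  1.2]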

View File

@@ -368,6 +368,16 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_ND_MIN_MAX(type, dev) \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax", \
scatter_nd_op::UpdateOp::MAX); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin", \
scatter_nd_op::UpdateOp::MIN); \
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
@@ -375,6 +385,9 @@ class ScatterNdUpdateOp : public OpKernel {
#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
REGISTER_SCATTER_ND_UPDATE(type, CPU);
#define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \
REGISTER_SCATTER_ND_MIN_MAX(type, CPU);
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);
@@ -386,6 +399,7 @@ TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU);
#define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
dev) \
@@ -412,6 +426,22 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
TensorScatterOp<dev##Device, type, index_type, \
scatter_nd_op::UpdateOp::SUB>)
#define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \
REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
.Device(DEVICE_##dev) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
TensorScatterOp<dev##Device, type, index_type, \
scatter_nd_op::UpdateOp::MIN>)
#define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \
REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
.Device(DEVICE_##dev) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
TensorScatterOp<dev##Device, type, index_type, \
scatter_nd_op::UpdateOp::MAX>)
#define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type) \
REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, CPU);
@@ -424,6 +454,14 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, CPU);
#define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type) \
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64, CPU);
#define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type) \
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64, CPU);
#define REGISTER_SCATTER_ND_TENSOR_CPU(type) \
REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type); \
@@ -431,6 +469,9 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
// Register TensorScatterUpdate/Add/Sub for all number types.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
// Register min/max operations only for Real number types
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU);
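
Because min/max kernels are registered only for real number types, a complex-typed request should fail to find a kernel rather than compute something meaningless. A hedged illustration (the exact error class may vary by build and execution mode):

import tensorflow as tf

t = tf.ones([4], dtype=tf.complex64)
try:
  tf.tensor_scatter_nd_max(t, [[0]], tf.constant([2 + 0j], dtype=tf.complex64))
except (tf.errors.NotFoundError, tf.errors.InvalidArgumentError) as e:
  print("no complex64 min/max kernel:", type(e).__name__)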
// Register only TensorScatterUpdate for string/bool types as well.
TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
@@ -446,6 +487,9 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
REGISTER_SCATTER_ND_UPDATE(type, GPU);
#define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
REGISTER_SCATTER_ND_MIN_MAX(type, GPU);
#define REGISTER_SCATTER_ND_ALL_GPU(type) \
REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \
REGISTER_SCATTER_ND_UPDATE_GPU(type); \
@@ -453,8 +497,11 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
// TODO(b/155931747): Use HostMemory for int32
TF_CALL_int32(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_int32(REGISTER_SCATTER_ND_MIN_MAX_GPU);
TF_CALL_int64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_int64(REGISTER_SCATTER_ND_MIN_MAX_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
@@ -467,12 +514,19 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
REGISTER_SCATTER_ND_UPDATE(type, SYCL);
#define REGISTER_SCATTER_ND_MIN_MAX_SYCL(type) \
REGISTER_SCATTER_ND_MIN_MAX(type, SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_MIN_MAX_SYCL);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MIN_MAX_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_MIN_MAX_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif // TENSORFLOW_USE_SYCL
@@ -488,13 +542,27 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, GPU);
#define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type) \
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64, GPU);
#define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type) \
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64, GPU);
#define REGISTER_SCATTER_ND_TENSOR_GPU(type) \
REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type); \
REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);
#define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \
REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type); \
REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type);
TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
TF_CALL_complex64(REGISTER_SCATTER_ND_TENSOR_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
@@ -502,6 +570,9 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_MIN_MAX_CPU
#undef REGISTER_SCATTER_ND_MIN_MAX_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
@@ -513,9 +584,13 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
#undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
#undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
#undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU
#undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU
#undef REGISTER_SCATTER_ND_TENSOR_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -771,16 +846,28 @@ namespace functor {
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)
#define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index) \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX)
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64)
#define DECLARE_GPU_SPECS_MIN_MAX(T) \
DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \
DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64)
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS_MIN_MAX
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX_MIN_MAX
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

View File

@@ -37,7 +37,7 @@ class OpKernelContext;
namespace scatter_nd_op {
enum class UpdateOp { ASSIGN, ADD, SUB, MIN, MAX };
} // namespace scatter_nd_op

View File

@@ -47,38 +47,57 @@ class OpKernelContext;
// Specialization of UpdateExecutor to CPU
namespace update_executor {
template <typename T, typename Input, typename Update, typename Output,
scatter_nd_op::UpdateOp OP>
class UpdateExecutor {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input value,
Update update, Output output);
};
template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output,
scatter_nd_op::UpdateOp::ASSIGN> {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
Update update, Output output) {
output.device(device) = update;
}
};
template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
Update update, Output output) {
output.device(device) += update;
}
};
template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
Update update, Output output) {
output.device(device) -= update;
}
};
template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
Update update, Output output) {
output.device(device) = output.cwiseMin(update);
}
};
template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
public:
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
Update update, Output output) {
output.device(device) = output.cwiseMax(update);
}
};
@@ -125,11 +144,12 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
}
@@ -159,11 +179,18 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);
#define REGISTER_SCATTER_ND_MIN_MAX(type) \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);
#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
@@ -209,11 +236,12 @@ struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
SYCLDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
}
@@ -241,7 +269,9 @@ struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
#define REGISTER_SCATTER_ND_MATH_SYCL(type) \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB); \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::MIN); \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::MAX);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)

View File

@@ -55,6 +55,20 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};
template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
CudaAtomicMax(out, val);
}
};
template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
CudaAtomicMin(out, val);
}
};
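
Using CudaAtomicMax / CudaAtomicMin means duplicate indices are well-defined on GPU: max and min are commutative and associative, so the result is independent of the order in which the atomics land. A small sketch of that property through the public op:

import tensorflow as tf

t = tf.zeros([3])
# Both updates target index 0; the larger value wins regardless of order.
print(tf.tensor_scatter_nd_max(t, [[0], [0]], [5.0, 2.0]).numpy())
# expected: [5., 0., 0.]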
// Specializations for std::complex, updating real and imaginary part
// individually. Even though this is not an atomic op anymore, it is safe
// because there is only one type of op per kernel.
@@ -166,20 +180,33 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB);
#define DECLARE_GPU_SPECS_INDEX_MINMAX(T, Index) \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX); \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64)
#define DECLARE_GPU_SPECS_MINMAX(T) \
DECLARE_GPU_SPECS_INDEX_MINMAX(T, int32); \
DECLARE_GPU_SPECS_INDEX_MINMAX(T, int64)
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int32(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_MINMAX
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_MINMAX
#undef DECLARE_GPU_SPECS_INDEX_OP
} // namespace tensorflow

View File

@@ -3117,6 +3117,24 @@ REGISTER_OP("TensorScatterSub")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterMin")
.Input("tensor: T")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterMax")
.Input("tensor: T")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("ScatterNdNonAliasingAdd") REGISTER_OP("ScatterNdNonAliasingAdd")
.Input("input: T") .Input("input: T")
.Input("indices: Tindices") .Input("indices: Tindices")

View File

@@ -240,6 +240,24 @@ REGISTER_OP("ResourceScatterNdSub")
.Attr("use_locking: bool = true")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdMin")
.Input("ref: resource")
.Input("indices: Tindices")
.Input("updates: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdMax")
.Input("ref: resource")
.Input("indices: Tindices")
.Input("updates: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
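
In builds that generate tf.raw_ops wrappers, the resource op should also be callable directly; a hedged sketch (the wrapper name is assumed to match the op name):

import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0, 4.0])
tf.raw_ops.ResourceScatterNdMax(
    ref=v.handle, indices=[[1], [3]], updates=[7.0, 0.5], use_locking=True)
print(v.numpy())  # expected: [1., 7., 3., 4.]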
REGISTER_OP("ScatterNdAdd") REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)") .Input("ref: Ref(T)")
.Input("indices: Tindices") .Input("indices: Tindices")
@ -260,6 +278,26 @@ REGISTER_OP("ScatterNdSub")
.Attr("use_locking: bool = false") .Attr("use_locking: bool = false")
.SetShapeFn(shape_inference::ScatterNdUpdateShape); .SetShapeFn(shape_inference::ScatterNdUpdateShape);
REGISTER_OP("ScatterNdMax")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
REGISTER_OP("ScatterNdMin")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
.Input("updates: T")
.Output("output_ref: Ref(T)")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
REGISTER_OP("CountUpTo") REGISTER_OP("CountUpTo")
.Input("ref: Ref(T)") .Input("ref: Ref(T)")
.Output("output: T") .Output("output: T")

View File

@@ -96,6 +96,14 @@ def _NumpyDiv(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
def _NumpyMin(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, np.minimum)
def _NumpyMax(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, np.maximum)
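
_NumpyMin and _NumpyMax plug np.minimum / np.maximum into the test file's generic reference scatter. As a standalone sketch of the semantics being verified (not the actual _NumpyScatterNd helper, whose body is outside this hunk), repeated indices fold together under the combining function:

import numpy as np

def _reference_scatter_nd_min(ref, indices, updates):
  # Apply updates one at a time so duplicates accumulate via np.minimum.
  out = ref.copy()
  for idx, upd in zip(indices, updates):
    out[tuple(idx)] = np.minimum(out[tuple(idx)], upd)
  return out

ref = np.array([3.0, 3.0, 3.0])
print(_reference_scatter_nd_min(ref, [[0], [0], [2]], [2.0, 1.0, 5.0]))
# expected: [1. 3. 3.]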
class StatefulScatterNdTest(test.TestCase):
def _VariableRankTest(self,
@@ -253,6 +261,8 @@ class StatefulScatterNdTest(test.TestCase):
"""This tests scatter_add using indices that repeat."""
self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add)
self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub)
self._ScatterRepeatIndicesTest(_NumpyMin, state_ops.scatter_nd_min)
self._ScatterRepeatIndicesTest(_NumpyMax, state_ops.scatter_nd_max)
# TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv.
# self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul)
# self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div)
@@ -276,6 +286,7 @@ class StatefulScatterNdTest(test.TestCase):
# scatter_nd ops is under control.
# tf.scatter_nd_mul, tf.scatter_nd_div,
for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
state_ops.scatter_nd_min, state_ops.scatter_nd_max,
state_ops.scatter_nd_update):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
@@ -763,6 +774,22 @@ class ScatterNdTensorTest(test.TestCase):
self.assertLess(err_added_wrt_updates, 2e-4)
self.assertLess(err_subbed_wrt_updates, 2e-4)
@test_util.run_in_graph_and_eager_modes
def testUpdateMinMax(self):
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant([0, 2, -1, 1.2], dtype=dtypes.float32)
t = array_ops.ones([8], dtype=dtypes.float32)
assigned = array_ops.tensor_scatter_update(t, indices, updates)
min_result = array_ops.tensor_scatter_min(t, indices, updates)
max_result = array_ops.tensor_scatter_max(t, indices, updates)
self.assertAllEqual(assigned,
constant_op.constant([1, -1, 1, 2, 0, 1, 1, 1.2]))
self.assertAllEqual(min_result,
constant_op.constant([1, -1, 1, 1, 0, 1, 1, 1]))
self.assertAllEqual(max_result,
constant_op.constant([1, 1, 1, 2, 1, 1, 1, 1.2]))
def testTensorScatterUpdateWithForwarding(self):
@def_function.function
def _TestFn():

View File

@@ -1202,6 +1202,78 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
name=name))
def scatter_nd_max(self, indices, updates, name=None):
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
See `tf.scatter_nd` for more details about how to make updates to
slices.
Args:
indices: The indices to be used in the operation.
updates: The values to be used in the operation.
name: the name of the operation.
Returns:
The updated variable.
"""
return self._lazy_read(
gen_state_ops.resource_scatter_nd_max(
self.handle,
indices,
ops.convert_to_tensor(updates, self.dtype),
name=name))
def scatter_nd_min(self, indices, updates, name=None):
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
See `tf.scatter_nd` for more details about how to make updates to
slices.
Args:
indices: The indices to be used in the operation.
updates: The values to be used in the operation.
name: the name of the operation.
Returns:
The updated variable.
"""
return self._lazy_read(
gen_state_ops.resource_scatter_nd_min(
self.handle,
indices,
ops.convert_to_tensor(updates, self.dtype),
name=name))
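
A hedged shape sketch for these methods, following the docstring's rank rule (`updates` has rank Q-1+P-K):

import tensorflow as tf

v = tf.Variable(tf.ones([4, 3]))     # ref: rank P = 2
indices = tf.constant([[1], [3]])    # rank Q = 2, K = 1 (row indices)
updates = tf.zeros([2, 3])           # rank Q-1+P-K = 2: [d_0, ref.shape[1]]
v.scatter_nd_min(indices, updates)   # rows 1 and 3 drop to 0.0
print(v.numpy())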
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask):
@@ -1949,6 +2021,14 @@ class _UnreadVariable(BaseResourceVariable):
return super(_UnreadVariable, self).scatter_nd_update(indices, updates,
name)
def scatter_nd_max(self, indices, updates, name=None):
with ops.control_dependencies([self._parent_op]):
return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
def scatter_nd_min(self, indices, updates, name=None):
with ops.control_dependencies([self._parent_op]):
return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
@property
def op(self):
"""The op for this variable."""

View File

@@ -2467,6 +2467,72 @@ class RefVariable(VariableV1, core.Tensor):
return gen_state_ops.scatter_nd_update(
self._variable, indices, updates, use_locking=True, name=name)
def scatter_nd_max(self, indices, updates, name=None):
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
See `tf.scatter_nd` for more details about how to make updates to
slices.
Args:
indices: The indices to be used in the operation.
updates: The values to be used in the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered maximum operation has completed.
"""
return gen_state_ops.scatter_nd_max(
self._variable, indices, updates, use_locking=True, name=name)
def scatter_nd_min(self, indices, updates, name=None):
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
```
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
```
See `tf.scatter_nd` for more details about how to make updates to
slices.
Args:
indices: The indices to be used in the operation.
updates: The values to be used in the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered minimum operation has completed.
"""
return gen_state_ops.scatter_nd_min(
self._variable, indices, updates, use_locking=True, name=name)
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask):

View File

@@ -2344,6 +2344,14 @@ tf_module {
name: "tensor_scatter_nd_add"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_max"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_min"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_sub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -3520,6 +3520,14 @@ tf_module {
name: "ResourceScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
@@ -3788,6 +3796,14 @@ tf_module {
name: "ScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdNonAliasingAdd"
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4864,6 +4880,14 @@ tf_module {
name: "TensorScatterAdd"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterMax"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterMin"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterSub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -1068,6 +1068,14 @@ tf_module {
name: "tensor_scatter_nd_add"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_max"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_min"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "tensor_scatter_nd_sub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -3520,6 +3520,14 @@ tf_module {
name: "ResourceScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "ResourceScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
@@ -3788,6 +3796,14 @@ tf_module {
name: "ScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "ScatterNdNonAliasingAdd"
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4864,6 +4880,14 @@ tf_module {
name: "TensorScatterAdd"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterMax"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterMin"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "TensorScatterSub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -1128,6 +1128,10 @@ renames = {
'tf.compat.v1.scatter_nd_add',
'tf.scatter_nd_sub':
'tf.compat.v1.scatter_nd_sub',
'tf.scatter_nd_max':
'tf.compat.v1.scatter_nd_max',
'tf.scatter_nd_min':
'tf.compat.v1.scatter_nd_min',
'tf.scatter_nd_update':
'tf.compat.v1.scatter_nd_update',
'tf.scatter_sub':