Add Registration for non-stateful scatter_nd_min and scatter_nd_max.
Manual import of #26923. Closes #26923. Fixes #20402 PiperOrigin-RevId: 314561995 Change-Id: Ibe1629d5ef769713d9058a6a72dc1421aa252ec4
This commit is contained in:
parent
0836e5e4ea
commit
23384f2ef1
|
@ -0,0 +1,31 @@
|
|||
op {
|
||||
graph_op_name: "ResourceScatterNdMax"
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
A resource handle. Must be from a VarHandleOp.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
A Tensor. Must have the same type as ref. A tensor of
|
||||
values whose element wise max is taken with ref
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
An optional bool. Defaults to True. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
op {
|
||||
graph_op_name: "ResourceScatterNdMin"
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
A resource handle. Must be from a VarHandleOp.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
A Tensor. Must have the same type as ref. A tensor of
|
||||
values whose element wise min is taken with ref.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
An optional bool. Defaults to True. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
op {
|
||||
graph_op_name: "ScatterNdMax"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
A mutable Tensor. Should be from a Variable node.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
to add to ref.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output_ref"
|
||||
description: <<END
|
||||
Same as ref. Returned as a convenience for operations that want
|
||||
to use the updated values after the update is done.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
An optional bool. Defaults to True. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Computes element-wise maximum."
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
op {
|
||||
graph_op_name: "ScatterNdMin"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
A mutable Tensor. Should be from a Variable node.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
A Tensor. Must be one of the following types: int32, int64.
|
||||
A tensor of indices into ref.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
A Tensor. Must have the same type as ref. A tensor of updated values
|
||||
to add to ref.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output_ref"
|
||||
description: <<END
|
||||
Same as ref. Returned as a convenience for operations that want
|
||||
to use the updated values after the update is done.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
An optional bool. Defaults to True. If True, the assignment will
|
||||
be protected by a lock; otherwise the behavior is undefined,
|
||||
but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Computes element-wise minimum."
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
op {
|
||||
graph_op_name: "TensorScatterMax"
|
||||
in_arg {
|
||||
name: "tensor"
|
||||
description: <<END
|
||||
Tensor to update.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
Index tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
Updates to scatter into output.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A new tensor copied from tensor whose values are element-wise maximum between tensor and updates according to the indices.
|
||||
END
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
op {
|
||||
graph_op_name: "TensorScatterMin"
|
||||
in_arg {
|
||||
name: "tensor"
|
||||
description: <<END
|
||||
Tensor to update.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "indices"
|
||||
description: <<END
|
||||
Index tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "updates"
|
||||
description: <<END
|
||||
Updates to scatter into output.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A new tensor copied from tensor whose values are element-wise minimum between tensor and updates according to the indices.
|
||||
END
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
op {
|
||||
graph_op_name: "ResourceScatterNdMax"
|
||||
visibility: HIDDEN
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
op {
|
||||
graph_op_name: "ResourceScatterNdMin"
|
||||
visibility: HIDDEN
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
op {
|
||||
graph_op_name: "TensorScatterMax"
|
||||
endpoint {
|
||||
name: "tensor_scatter_nd_max"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
op {
|
||||
graph_op_name: "TensorScatterMin"
|
||||
endpoint {
|
||||
name: "tensor_scatter_nd_min"
|
||||
}
|
||||
}
|
|
@ -368,6 +368,16 @@ class ScatterNdUpdateOp : public OpKernel {
|
|||
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
|
||||
type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MIN_MAX(type, dev) \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMax", \
|
||||
scatter_nd_op::UpdateOp::MAX); \
|
||||
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMin", \
|
||||
scatter_nd_op::UpdateOp::MIN); \
|
||||
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
|
||||
type, dev, "ResourceScatterNdMin", scatter_nd_op::UpdateOp::MIN); \
|
||||
REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
|
||||
type, dev, "ResourceScatterNdMax", scatter_nd_op::UpdateOp::MAX);
|
||||
|
||||
// Registers CPU kernels.
|
||||
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
|
||||
REGISTER_SCATTER_ND_ADD_SUB(type, CPU);
|
||||
|
@ -375,6 +385,9 @@ class ScatterNdUpdateOp : public OpKernel {
|
|||
#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
|
||||
REGISTER_SCATTER_ND_UPDATE(type, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MIN_MAX_CPU(type) \
|
||||
REGISTER_SCATTER_ND_MIN_MAX(type, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
|
||||
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);
|
||||
|
||||
|
@ -386,6 +399,7 @@ TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
|
|||
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, index_type, \
|
||||
dev) \
|
||||
|
@ -412,6 +426,22 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
|
|||
TensorScatterOp<dev##Device, type, index_type, \
|
||||
scatter_nd_op::UpdateOp::SUB>)
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, index_type, dev) \
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorScatterMin") \
|
||||
.Device(DEVICE_##dev) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
TensorScatterOp<dev##Device, type, index_type, \
|
||||
scatter_nd_op::UpdateOp::MIN>)
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, index_type, dev) \
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorScatterMax") \
|
||||
.Device(DEVICE_##dev) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
TensorScatterOp<dev##Device, type, index_type, \
|
||||
scatter_nd_op::UpdateOp::MAX>)
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int32, CPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE(type, int64, CPU);
|
||||
|
@ -424,6 +454,14 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
|
|||
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, CPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MIN_CPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, CPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MAX_CPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, CPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64, CPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_CPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU(type); \
|
||||
REGISTER_SCATTER_ND_TENSOR_ADD_CPU(type); \
|
||||
|
@ -431,6 +469,9 @@ TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
|
|||
|
||||
// Register TensorScatterUpdate/Add/Sub for all number types.
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_CPU);
|
||||
// Register min/max operations only for Real number types
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MIN_CPU);
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_TENSOR_MAX_CPU);
|
||||
// Register only TensorScatterUpdate for string/bool types as well.
|
||||
TF_CALL_tstring(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
|
||||
|
@ -446,6 +487,9 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
|
|||
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
|
||||
REGISTER_SCATTER_ND_UPDATE(type, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
|
||||
REGISTER_SCATTER_ND_MIN_MAX(type, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_ALL_GPU(type) \
|
||||
REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \
|
||||
REGISTER_SCATTER_ND_UPDATE_GPU(type); \
|
||||
|
@ -453,8 +497,11 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
|
|||
|
||||
// TODO(b/155931747): Use HostMemory for int32
|
||||
TF_CALL_int32(REGISTER_SCATTER_ND_ALL_GPU);
|
||||
TF_CALL_int32(REGISTER_SCATTER_ND_MIN_MAX_GPU);
|
||||
TF_CALL_int64(REGISTER_SCATTER_ND_ALL_GPU);
|
||||
TF_CALL_int64(REGISTER_SCATTER_ND_MIN_MAX_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
|
||||
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
|
||||
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
|
||||
|
||||
|
@ -467,12 +514,19 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
|
|||
#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
|
||||
REGISTER_SCATTER_ND_UPDATE(type, SYCL);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MIN_MAX_GPU(type) \
|
||||
REGISTER_SCATTER_ND_MIN_MAX(type, SYCL);
|
||||
|
||||
TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
|
||||
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
|
||||
TF_CALL_int32(REGISTER_SCATTER_ND_MIN_MAX_SYCL);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MIN_MAX_SYCL);
|
||||
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
|
||||
#undef REGISTER_SCATTER_ND_MIN_MAX_SYCL
|
||||
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
|
@ -488,13 +542,27 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
|
|||
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int32, GPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE(type, int64, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int32, GPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE(type, int64, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int32, GPU); \
|
||||
REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE(type, int64, GPU);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_GPU(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_ADD_GPU(type); \
|
||||
REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU(type); \
|
||||
REGISTER_SCATTER_ND_TENSOR_SUB_GPU(type);
|
||||
|
||||
#define REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX(type) \
|
||||
REGISTER_SCATTER_ND_TENSOR_MIN_GPU(type); \
|
||||
REGISTER_SCATTER_ND_TENSOR_MAX_GPU(type);
|
||||
|
||||
TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
|
||||
TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
|
||||
TF_CALL_complex64(REGISTER_SCATTER_ND_TENSOR_GPU);
|
||||
TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
|
||||
|
||||
|
@ -502,6 +570,9 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
|
|||
#undef REGISTER_SCATTER_ND_ADD_SUB
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
|
||||
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
|
||||
#undef REGISTER_SCATTER_ND_MIN_MAX
|
||||
#undef REGISTER_SCATTER_ND_MIN_MAX_CPU
|
||||
#undef REGISTER_SCATTER_ND_MIN_MAX_GPU
|
||||
#undef REGISTER_SCATTER_ND_UPDATE
|
||||
#undef REGISTER_SCATTER_ND_UPDATE_CPU
|
||||
#undef REGISTER_SCATTER_ND_UPDATE_GPU
|
||||
|
@ -513,9 +584,13 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
|
|||
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_TYPE_INDEX_TYPE
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_ADD_TYPE_INDEX_TYPE
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_SUB_TYPE_INDEX_TYPE
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_MIN_TYPE_INDEX_TYPE
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_MAX_TYPE_INDEX_TYPE
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_UPDATE_GPU
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_ADD_GPU
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_MIN_GPU
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_MAX_GPU
|
||||
#undef REGISTER_SCATTER_ND_TENSOR_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -771,16 +846,28 @@ namespace functor {
|
|||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)
|
||||
|
||||
#define DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, Index) \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN); \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX)
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64)
|
||||
|
||||
#define DECLARE_GPU_SPECS_MIN_MAX(T) \
|
||||
DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX_MIN_MAX(T, int64)
|
||||
|
||||
TF_CALL_int32(DECLARE_GPU_SPECS);
|
||||
TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPECS);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS_MIN_MAX
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_INDEX_MIN_MAX
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
#undef DECLARE_GPU_SPECS_INDEX_OP
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class OpKernelContext;
|
|||
|
||||
namespace scatter_nd_op {
|
||||
|
||||
enum class UpdateOp { ASSIGN, ADD, SUB };
|
||||
enum class UpdateOp { ASSIGN, ADD, SUB, MIN, MAX };
|
||||
|
||||
} // namespace scatter_nd_op
|
||||
|
||||
|
|
|
@ -47,38 +47,57 @@ class OpKernelContext;
|
|||
// Specialization of UpdateExecutor to CPU
|
||||
namespace update_executor {
|
||||
|
||||
template <typename Input, typename Update, typename Output,
|
||||
template <typename T, typename Input, typename Update, typename Output,
|
||||
scatter_nd_op::UpdateOp OP>
|
||||
class UpdateExecutor {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(Input value, Update update,
|
||||
Output output);
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input value,
|
||||
Update update, Output output);
|
||||
};
|
||||
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> {
|
||||
template <typename T, typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<T, Input, Update, Output,
|
||||
scatter_nd_op::UpdateOp::ASSIGN> {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output = update;
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
|
||||
Update update, Output output) {
|
||||
output.device(device) = update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
|
||||
template <typename T, typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output += update;
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
|
||||
Update update, Output output) {
|
||||
output.device(device) += update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
|
||||
template <typename T, typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
|
||||
Output output) {
|
||||
output -= update;
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
|
||||
Update update, Output output) {
|
||||
output.device(device) -= update;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
|
||||
Update update, Output output) {
|
||||
output.device(device) = output.cwiseMin(update);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Input, typename Update, typename Output>
|
||||
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
|
||||
public:
|
||||
EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
|
||||
Update update, Output output) {
|
||||
output.device(device) = output.cwiseMax(update);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -125,11 +144,12 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
|
|||
break;
|
||||
} else {
|
||||
auto input_chip = Toutput.template chip<0>(i);
|
||||
auto output_chip = input_chip.device(d);
|
||||
auto output_chip = input_chip;
|
||||
auto update_chip = Tupdates.template chip<0>(loc);
|
||||
update_executor::UpdateExecutor<
|
||||
decltype(input_chip), decltype(update_chip), decltype(output_chip),
|
||||
OP>::Execute(input_chip, update_chip, output_chip);
|
||||
CPUDevice, decltype(input_chip), decltype(update_chip),
|
||||
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
|
||||
output_chip);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -159,11 +179,18 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
|
|||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);
|
||||
|
||||
#define REGISTER_SCATTER_ND_MIN_MAX(type) \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
|
||||
REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
|
||||
REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);
|
||||
|
||||
#undef REGISTER_SCATTER_ND_MATH
|
||||
#undef REGISTER_SCATTER_ND_MIN_MAX
|
||||
#undef REGISTER_SCATTER_ND_UPDATE
|
||||
#undef REGISTER_SCATTER_ND_INDEX
|
||||
#undef REGISTER_SCATTER_ND_FULL
|
||||
|
@ -209,11 +236,12 @@ struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
|
|||
break;
|
||||
} else {
|
||||
auto input_chip = Toutput.template chip<0>(i);
|
||||
auto output_chip = input_chip.device(d);
|
||||
auto output_chip = input_chip;
|
||||
auto update_chip = Tupdates.template chip<0>(loc);
|
||||
update_executor::UpdateExecutor<
|
||||
decltype(input_chip), decltype(update_chip), decltype(output_chip),
|
||||
OP>::Execute(input_chip, update_chip, output_chip);
|
||||
SYCLDevice, decltype(input_chip), decltype(update_chip),
|
||||
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
|
||||
output_chip);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,7 +269,9 @@ struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
|
|||
|
||||
#define REGISTER_SCATTER_ND_MATH_SYCL(type) \
|
||||
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \
|
||||
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB);
|
||||
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB); \
|
||||
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::MIN); \
|
||||
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::MAX);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL)
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)
|
||||
|
|
|
@ -55,6 +55,20 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||
CudaAtomicMax(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||
CudaAtomicMin(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
// Specializations for std::complex, updating real and imaginary part
|
||||
// individually. Even though this is not an atomic op anymore, it is safe
|
||||
// because there is only one type of op per kernel.
|
||||
|
@ -166,20 +180,33 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
|
|||
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD); \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB);
|
||||
|
||||
#define DECLARE_GPU_SPECS_INDEX_MINMAX(T, Index) \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX) \
|
||||
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN);
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64)
|
||||
|
||||
#define DECLARE_GPU_SPECS_MINMAX(T) \
|
||||
DECLARE_GPU_SPECS_INDEX_MINMAX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX_MINMAX(T, int64)
|
||||
|
||||
TF_CALL_int32(DECLARE_GPU_SPECS);
|
||||
TF_CALL_int32(DECLARE_GPU_SPECS_MINMAX);
|
||||
TF_CALL_int64(DECLARE_GPU_SPECS);
|
||||
TF_CALL_int64(DECLARE_GPU_SPECS_MINMAX);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MINMAX);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPECS);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_MINMAX
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
#undef DECLARE_GPU_SPECS_INDEX_MINMAX
|
||||
#undef DECLARE_GPU_SPECS_INDEX_OP
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -3117,6 +3117,24 @@ REGISTER_OP("TensorScatterSub")
|
|||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterMin")
|
||||
.Input("tensor: T")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterMax")
|
||||
.Input("tensor: T")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("ScatterNdNonAliasingAdd")
|
||||
.Input("input: T")
|
||||
.Input("indices: Tindices")
|
||||
|
|
|
@ -240,6 +240,24 @@ REGISTER_OP("ResourceScatterNdSub")
|
|||
.Attr("use_locking: bool = true")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdMin")
|
||||
.Input("ref: resource")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdMax")
|
||||
.Input("ref: resource")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdAdd")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
|
@ -260,6 +278,26 @@ REGISTER_OP("ScatterNdSub")
|
|||
.Attr("use_locking: bool = false")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdMax")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdMin")
|
||||
.Input("ref: Ref(T)")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Output("output_ref: Ref(T)")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("CountUpTo")
|
||||
.Input("ref: Ref(T)")
|
||||
.Output("output: T")
|
||||
|
|
|
@ -96,6 +96,14 @@ def _NumpyDiv(ref, indices, updates):
|
|||
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
|
||||
|
||||
|
||||
def _NumpyMin(ref, indices, updates):
|
||||
return _NumpyScatterNd(ref, indices, updates, np.minimum)
|
||||
|
||||
|
||||
def _NumpyMax(ref, indices, updates):
|
||||
return _NumpyScatterNd(ref, indices, updates, np.maximum)
|
||||
|
||||
|
||||
class StatefulScatterNdTest(test.TestCase):
|
||||
|
||||
def _VariableRankTest(self,
|
||||
|
@ -253,6 +261,8 @@ class StatefulScatterNdTest(test.TestCase):
|
|||
"""This tests scatter_add using indices that repeat."""
|
||||
self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add)
|
||||
self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub)
|
||||
self._ScatterRepeatIndicesTest(_NumpyMin, state_ops.scatter_nd_min)
|
||||
self._ScatterRepeatIndicesTest(_NumpyMax, state_ops.scatter_nd_max)
|
||||
# TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv.
|
||||
# self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul)
|
||||
# self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div)
|
||||
|
@ -276,6 +286,7 @@ class StatefulScatterNdTest(test.TestCase):
|
|||
# scatter_nd ops is under control.
|
||||
# tf.scatter_nd_mul, tf.scatter_nd_div,
|
||||
for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
|
||||
state_ops.scatter_nd_min, state_ops.scatter_nd_max,
|
||||
state_ops.scatter_nd_update):
|
||||
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
||||
updates = np.array([-3, -4, -5]).astype(np.float32)
|
||||
|
@ -763,6 +774,22 @@ class ScatterNdTensorTest(test.TestCase):
|
|||
self.assertLess(err_added_wrt_updates, 2e-4)
|
||||
self.assertLess(err_subbed_wrt_updates, 2e-4)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testUpdateMinMax(self):
|
||||
indices = constant_op.constant([[4], [3], [1], [7]])
|
||||
updates = constant_op.constant([0, 2, -1, 1.2], dtype=dtypes.float32)
|
||||
t = array_ops.ones([8], dtype=dtypes.float32)
|
||||
assigned = array_ops.tensor_scatter_update(t, indices, updates)
|
||||
min_result = array_ops.tensor_scatter_min(t, indices, updates)
|
||||
max_result = array_ops.tensor_scatter_max(t, indices, updates)
|
||||
|
||||
self.assertAllEqual(assigned,
|
||||
constant_op.constant([1, -1, 1, 2, 0, 1, 1, 1.2]))
|
||||
self.assertAllEqual(min_result,
|
||||
constant_op.constant([1, -1, 1, 1, 0, 1, 1, 1]))
|
||||
self.assertAllEqual(max_result,
|
||||
constant_op.constant([1, 1, 1, 2, 1, 1, 1, 1.2]))
|
||||
|
||||
def testTensorScatterUpdateWithForwarding(self):
|
||||
@def_function.function
|
||||
def _TestFn():
|
||||
|
|
|
@ -1202,6 +1202,78 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
|||
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
||||
name=name))
|
||||
|
||||
def scatter_nd_max(self, indices, updates, name=None):
|
||||
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
See `tf.scatter_nd` for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
Args:
|
||||
indices: The indices to be used in the operation.
|
||||
updates: The values to be used in the operation.
|
||||
name: the name of the operation.
|
||||
|
||||
Returns:
|
||||
The updated variable.
|
||||
"""
|
||||
return self._lazy_read(
|
||||
gen_state_ops.resource_scatter_nd_max(
|
||||
self.handle,
|
||||
indices,
|
||||
ops.convert_to_tensor(updates, self.dtype),
|
||||
name=name))
|
||||
|
||||
def scatter_nd_min(self, indices, updates, name=None):
|
||||
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
See `tf.scatter_nd` for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
Args:
|
||||
indices: The indices to be used in the operation.
|
||||
updates: The values to be used in the operation.
|
||||
name: the name of the operation.
|
||||
|
||||
Returns:
|
||||
The updated variable.
|
||||
"""
|
||||
return self._lazy_read(
|
||||
gen_state_ops.resource_scatter_nd_min(
|
||||
self.handle,
|
||||
indices,
|
||||
ops.convert_to_tensor(updates, self.dtype),
|
||||
name=name))
|
||||
|
||||
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
|
||||
end_mask, ellipsis_mask, new_axis_mask,
|
||||
shrink_axis_mask):
|
||||
|
@ -1949,6 +2021,14 @@ class _UnreadVariable(BaseResourceVariable):
|
|||
return super(_UnreadVariable, self).scatter_nd_update(indices, updates,
|
||||
name)
|
||||
|
||||
def scatter_nd_max(self, indices, updates, name=None):
|
||||
with ops.control_dependencies([self._parent_op]):
|
||||
return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
|
||||
|
||||
def scatter_nd_min(self, indices, updates, name=None):
|
||||
with ops.control_dependencies([self._parent_op]):
|
||||
return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""The op for this variable."""
|
||||
|
|
|
@ -2467,6 +2467,72 @@ class RefVariable(VariableV1, core.Tensor):
|
|||
return gen_state_ops.scatter_nd_update(
|
||||
self._variable, indices, updates, use_locking=True, name=name)
|
||||
|
||||
def scatter_nd_max(self, indices, updates, name=None):
|
||||
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
See `tf.scatter_nd` for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
Args:
|
||||
indices: The indices to be used in the operation.
|
||||
updates: The values to be used in the operation.
|
||||
name: the name of the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` that will hold the new value of this variable after
|
||||
the scattered addition has completed.
|
||||
"""
|
||||
return gen_state_ops.scatter_nd_max(
|
||||
self._variable, indices, updates, use_locking=True, name=name)
|
||||
|
||||
def scatter_nd_min(self, indices, updates, name=None):
|
||||
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
|
||||
|
||||
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `ref`.
|
||||
It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
|
||||
|
||||
The innermost dimension of `indices` (with length `K`) corresponds to
|
||||
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
|
||||
dimension of `ref`.
|
||||
|
||||
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
|
||||
|
||||
```
|
||||
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
|
||||
```
|
||||
|
||||
See `tf.scatter_nd` for more details about how to make updates to
|
||||
slices.
|
||||
|
||||
Args:
|
||||
indices: The indices to be used in the operation.
|
||||
updates: The values to be used in the operation.
|
||||
name: the name of the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` that will hold the new value of this variable after
|
||||
the scattered addition has completed.
|
||||
"""
|
||||
return gen_state_ops.scatter_nd_min(
|
||||
self._variable, indices, updates, use_locking=True, name=name)
|
||||
|
||||
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
|
||||
end_mask, ellipsis_mask, new_axis_mask,
|
||||
shrink_axis_mask):
|
||||
|
|
|
@ -2344,6 +2344,14 @@ tf_module {
|
|||
name: "tensor_scatter_nd_add"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_max"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_min"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_sub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -3520,6 +3520,14 @@ tf_module {
|
|||
name: "ResourceScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
|
@ -3788,6 +3796,14 @@ tf_module {
|
|||
name: "ScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdNonAliasingAdd"
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -4864,6 +4880,14 @@ tf_module {
|
|||
name: "TensorScatterAdd"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMax"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMin"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterSub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -1068,6 +1068,14 @@ tf_module {
|
|||
name: "tensor_scatter_nd_add"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_max"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_min"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_sub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -3520,6 +3520,14 @@ tf_module {
|
|||
name: "ResourceScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
|
@ -3788,6 +3796,14 @@ tf_module {
|
|||
name: "ScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdNonAliasingAdd"
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -4864,6 +4880,14 @@ tf_module {
|
|||
name: "TensorScatterAdd"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMax"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMin"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterSub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
|
|
@ -1128,6 +1128,10 @@ renames = {
|
|||
'tf.compat.v1.scatter_nd_add',
|
||||
'tf.scatter_nd_sub':
|
||||
'tf.compat.v1.scatter_nd_sub',
|
||||
'tf.scatter_nd_max':
|
||||
'tf.compat.v1.scatter_nd_max',
|
||||
'tf.scatter_nd_min':
|
||||
'tf.compat.v1.scatter_nd_min',
|
||||
'tf.scatter_nd_update':
|
||||
'tf.compat.v1.scatter_nd_update',
|
||||
'tf.scatter_sub':
|
||||
|
|
Loading…
Reference in New Issue