[TF:XLA] Add support for tensor_scatter_{add,sub,update}
These are the same as scatter_nd, but take a tensor as an input to be scattered into instead of starting out with a zero tensor. PiperOrigin-RevId: 251471865
This commit is contained in:
parent
180f28a266
commit
854a1e4cb1
@ -190,5 +190,34 @@ class ScatterNdTest(xla_test.XLATestCase):
|
||||
self._runScatterNd(indices, updates, [6])
|
||||
|
||||
|
||||
class ScatterNdTensorTest(xla_test.XLATestCase):
|
||||
|
||||
def _runScatter(self, op):
|
||||
indices_np = np.array([[4], [3], [1], [7]], dtype=np.int32)
|
||||
updates_np = np.array([9, 10, 11, 12], dtype=np.float32)
|
||||
with self.session() as sess, self.test_scope():
|
||||
indices = array_ops.placeholder(indices_np.dtype, shape=indices_np.shape)
|
||||
updates = array_ops.placeholder(updates_np.dtype, shape=updates_np.shape)
|
||||
t = array_ops.ones([8], dtype=np.float32)
|
||||
|
||||
out = op(t, indices, updates)
|
||||
return sess.run(out, feed_dict={indices: indices_np, updates: updates_np})
|
||||
|
||||
def testAdd(self):
|
||||
self.assertAllEqual(
|
||||
self._runScatter(array_ops.tensor_scatter_add),
|
||||
np.array([1, 12, 1, 11, 10, 1, 1, 13], dtype=np.float32))
|
||||
|
||||
def testSub(self):
|
||||
self.assertAllEqual(
|
||||
self._runScatter(array_ops.tensor_scatter_sub),
|
||||
np.array([1, -10, 1, -9, -8, 1, 1, -11], dtype=np.float32))
|
||||
|
||||
def testUpdate(self):
|
||||
self.assertAllEqual(
|
||||
self._runScatter(array_ops.tensor_scatter_update),
|
||||
np.array([1, 11, 1, 10, 9, 1, 1, 12], dtype=np.float32))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -125,5 +125,80 @@ class ScatterNdOp : public XlaOpKernel {
|
||||
REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
|
||||
ScatterNdOp);
|
||||
|
||||
void CompileTensorScatter(
|
||||
XlaOpKernelContext* context,
|
||||
const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
|
||||
combiner) {
|
||||
TensorShape buffer_shape = context->InputShape(0);
|
||||
TensorShape indices_shape = context->InputShape(1);
|
||||
TensorShape updates_shape = context->InputShape(2);
|
||||
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
|
||||
errors::InvalidArgument("Output must be at least 1-D, ",
|
||||
"got shape: ", buffer_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
|
||||
updates_shape.num_elements() == 0),
|
||||
errors::InvalidArgument(
|
||||
"Indices and updates specified for empty output. indices shape: ",
|
||||
indices_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape));
|
||||
|
||||
xla::XlaBuilder* builder = context->builder();
|
||||
auto buffer = context->Input(0);
|
||||
auto indices = context->Input(1);
|
||||
auto updates = context->Input(2);
|
||||
auto result = XlaScatter(buffer, updates, indices,
|
||||
/*indices_are_vectors=*/true, combiner, builder);
|
||||
OP_REQUIRES_OK(context, result.status());
|
||||
context->SetOutput(0, result.ValueOrDie());
|
||||
}
|
||||
|
||||
class TensorScatterAddOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorScatterAddOp(OpKernelConstruction* context)
|
||||
: XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
CompileTensorScatter(context,
|
||||
[](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
|
||||
return xla::Add(x, y);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
class TensorScatterSubOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorScatterSubOp(OpKernelConstruction* context)
|
||||
: XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
CompileTensorScatter(context,
|
||||
[](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
|
||||
return xla::Sub(x, y);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
class TensorScatterUpdateOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorScatterUpdateOp(OpKernelConstruction* context)
|
||||
: XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
CompileTensorScatter(
|
||||
context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; });
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
|
||||
REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
|
||||
REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user