[TF:XLA] Add support for tensor_scatter_{add,sub,update}

These are the same as scatter_nd, but take a tensor as an input to be scattered
into instead of starting out with a zero tensor.

PiperOrigin-RevId: 251471865
This commit is contained in:
Benjamin Kramer 2019-06-04 10:55:34 -07:00 committed by TensorFlower Gardener
parent 180f28a266
commit 854a1e4cb1
2 changed files with 104 additions and 0 deletions

View File

@@ -190,5 +190,34 @@ class ScatterNdTest(xla_test.XLATestCase):
self._runScatterNd(indices, updates, [6])
class ScatterNdTensorTest(xla_test.XLATestCase):
def _runScatter(self, op):
indices_np = np.array([[4], [3], [1], [7]], dtype=np.int32)
updates_np = np.array([9, 10, 11, 12], dtype=np.float32)
with self.session() as sess, self.test_scope():
indices = array_ops.placeholder(indices_np.dtype, shape=indices_np.shape)
updates = array_ops.placeholder(updates_np.dtype, shape=updates_np.shape)
t = array_ops.ones([8], dtype=np.float32)
out = op(t, indices, updates)
return sess.run(out, feed_dict={indices: indices_np, updates: updates_np})
def testAdd(self):
self.assertAllEqual(
self._runScatter(array_ops.tensor_scatter_add),
np.array([1, 12, 1, 11, 10, 1, 1, 13], dtype=np.float32))
def testSub(self):
self.assertAllEqual(
self._runScatter(array_ops.tensor_scatter_sub),
np.array([1, -10, 1, -9, -8, 1, 1, -11], dtype=np.float32))
def testUpdate(self):
self.assertAllEqual(
self._runScatter(array_ops.tensor_scatter_update),
np.array([1, 11, 1, 10, 9, 1, 1, 12], dtype=np.float32))
if __name__ == "__main__":
test.main()

View File

@@ -125,5 +125,80 @@ class ScatterNdOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
ScatterNdOp);
void CompileTensorScatter(
XlaOpKernelContext* context,
const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
combiner) {
TensorShape buffer_shape = context->InputShape(0);
TensorShape indices_shape = context->InputShape(1);
TensorShape updates_shape = context->InputShape(2);
OP_REQUIRES(
context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
errors::InvalidArgument("Output must be at least 1-D, ",
"got shape: ", buffer_shape.DebugString()));
OP_REQUIRES(
context,
buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
updates_shape.num_elements() == 0),
errors::InvalidArgument(
"Indices and updates specified for empty output. indices shape: ",
indices_shape.DebugString()));
OP_REQUIRES_OK(
context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape));
xla::XlaBuilder* builder = context->builder();
auto buffer = context->Input(0);
auto indices = context->Input(1);
auto updates = context->Input(2);
auto result = XlaScatter(buffer, updates, indices,
/*indices_are_vectors=*/true, combiner, builder);
OP_REQUIRES_OK(context, result.status());
context->SetOutput(0, result.ValueOrDie());
}
class TensorScatterAddOp : public XlaOpKernel {
public:
explicit TensorScatterAddOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
CompileTensorScatter(context,
[](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
return xla::Add(x, y);
});
}
};
class TensorScatterSubOp : public XlaOpKernel {
public:
explicit TensorScatterSubOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
CompileTensorScatter(context,
[](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
return xla::Sub(x, y);
});
}
};
class TensorScatterUpdateOp : public XlaOpKernel {
public:
explicit TensorScatterUpdateOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
CompileTensorScatter(
context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; });
}
};
REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
} // namespace
} // namespace tensorflow