Add CompileTimeConstantInput to XlaGather.

PiperOrigin-RevId: 326175243
Change-Id: I8540c0d09709be40ff6d5bfaa50a4eac918a6627
This commit is contained in:
A. Unique TensorFlower 2020-08-11 23:08:41 -07:00 committed by TensorFlower Gardener
parent af9cb379b6
commit 335b03c45c
2 changed files with 21 additions and 1 deletions

View File

@ -79,6 +79,25 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
args=(v,),
expected=np.tile(v, (7, 42, 1, 1)))
@test_util.disable_mlir_bridge('Not supported yet')
def testGather(self):
operand = np.arange(10, dtype=np.int32).reshape([2, 5])
start_indices = np.array([2], np.int32)
slice_sizes = np.array([1, 3], np.int32)
def gather(operand, start_indices):
dimension_numbers = xla_data_pb2.GatherDimensionNumbers()
dimension_numbers.offset_dims.extend([1])
dimension_numbers.collapsed_slice_dims.extend([0])
dimension_numbers.start_index_map.extend([0])
dimension_numbers.index_vector_dim = 1
return xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
self._assertOpOutputMatchesExpected(
gather,
args=(operand, start_indices),
expected=np.array([[5, 6, 7]]))
@test_util.disable_mlir_bridge('Dynamic result types not supported')
def testShiftRightLogical(self):
self._assertOpOutputMatchesExpected(

View File

@ -49,7 +49,8 @@ class GatherOp : public XlaOpKernel {
bool indices_are_sorted_;
};
REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"),
GatherOp);
class ScatterOp : public XlaOpKernel {
public: