Add CompileTimeConstantInput to XlaGather.
PiperOrigin-RevId: 326175243 Change-Id: I8540c0d09709be40ff6d5bfaa50a4eac918a6627
This commit is contained in:
parent
af9cb379b6
commit
335b03c45c
@ -79,6 +79,25 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(v,),
|
||||
expected=np.tile(v, (7, 42, 1, 1)))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testGather(self):
|
||||
operand = np.arange(10, dtype=np.int32).reshape([2, 5])
|
||||
start_indices = np.array([2], np.int32)
|
||||
slice_sizes = np.array([1, 3], np.int32)
|
||||
|
||||
def gather(operand, start_indices):
|
||||
dimension_numbers = xla_data_pb2.GatherDimensionNumbers()
|
||||
dimension_numbers.offset_dims.extend([1])
|
||||
dimension_numbers.collapsed_slice_dims.extend([0])
|
||||
dimension_numbers.start_index_map.extend([0])
|
||||
dimension_numbers.index_vector_dim = 1
|
||||
return xla.gather(operand, start_indices, dimension_numbers, slice_sizes)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
gather,
|
||||
args=(operand, start_indices),
|
||||
expected=np.array([[5, 6, 7]]))
|
||||
|
||||
@test_util.disable_mlir_bridge('Dynamic result types not supported')
|
||||
def testShiftRightLogical(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
@ -49,7 +49,8 @@ class GatherOp : public XlaOpKernel {
|
||||
bool indices_are_sorted_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
|
||||
REGISTER_XLA_OP(Name("XlaGather").CompileTimeConstantInput("slice_sizes"),
|
||||
GatherOp);
|
||||
|
||||
class ScatterOp : public XlaOpKernel {
|
||||
public:
|
||||
|
Loading…
Reference in New Issue
Block a user