Support DT_INT32 for Cumsum/Cumprod.
PiperOrigin-RevId: 222429176
This commit is contained in:
parent
ed52111074
commit
4085979982
@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse):
|
||||
|
||||
class CumsumTest(xla_test.XLATestCase):
|
||||
|
||||
valid_dtypes = [np.float32]
|
||||
valid_dtypes = [np.float32, np.int32]
|
||||
|
||||
def axis_dtypes(self):
|
||||
return set(self.int_types).intersection([np.int32, np.int64])
|
||||
@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase):
|
||||
|
||||
class CumprodTest(xla_test.XLATestCase):
|
||||
|
||||
valid_dtypes = [np.float32]
|
||||
valid_dtypes = [np.float32, np.int32]
|
||||
|
||||
def axis_dtypes(self):
|
||||
return set(self.int_types).intersection([np.int32, np.int64])
|
||||
|
@ -39,8 +39,8 @@ namespace {
|
||||
|
||||
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
|
||||
// the type constraint.
|
||||
constexpr std::array<DataType, 3> kScanOpTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT}};
|
||||
constexpr std::array<DataType, 4> kScanOpTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
|
||||
|
||||
class ScanOp : public XlaOpKernel {
|
||||
public:
|
||||
|
Loading…
Reference in New Issue
Block a user