Support DT_INT32 for Cumsum/Cumprod.
PiperOrigin-RevId: 222429176
This commit is contained in:
parent
ed52111074
commit
4085979982
@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse):
|
|||||||
|
|
||||||
class CumsumTest(xla_test.XLATestCase):
|
class CumsumTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
valid_dtypes = [np.float32]
|
valid_dtypes = [np.float32, np.int32]
|
||||||
|
|
||||||
def axis_dtypes(self):
|
def axis_dtypes(self):
|
||||||
return set(self.int_types).intersection([np.int32, np.int64])
|
return set(self.int_types).intersection([np.int32, np.int64])
|
||||||
@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
class CumprodTest(xla_test.XLATestCase):
|
class CumprodTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
valid_dtypes = [np.float32]
|
valid_dtypes = [np.float32, np.int32]
|
||||||
|
|
||||||
def axis_dtypes(self):
|
def axis_dtypes(self):
|
||||||
return set(self.int_types).intersection([np.int32, np.int64])
|
return set(self.int_types).intersection([np.int32, np.int64])
|
||||||
|
@ -39,8 +39,8 @@ namespace {
|
|||||||
|
|
||||||
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
|
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
|
||||||
// the type constraint.
|
// the type constraint.
|
||||||
constexpr std::array<DataType, 3> kScanOpTypes = {
|
constexpr std::array<DataType, 4> kScanOpTypes = {
|
||||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT}};
|
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
|
||||||
|
|
||||||
class ScanOp : public XlaOpKernel {
|
class ScanOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
|
Loading…
Reference in New Issue
Block a user