Support DT_INT32 for Cumsum/Cumprod.

PiperOrigin-RevId: 222429176
This commit is contained in:
Tong Shen 2018-11-21 10:36:21 -08:00 committed by TensorFlower Gardener
parent ed52111074
commit 4085979982
2 changed files with 4 additions and 4 deletions

View File

@@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse):
class CumsumTest(xla_test.XLATestCase):
-  valid_dtypes = [np.float32]
+  valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
@@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase):
class CumprodTest(xla_test.XLATestCase):
-  valid_dtypes = [np.float32]
+  valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])

View File

@@ -39,8 +39,8 @@ namespace {
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
// the type constraint.
-constexpr std::array<DataType, 3> kScanOpTypes = {
-    {DT_HALF, DT_BFLOAT16, DT_FLOAT}};
+constexpr std::array<DataType, 4> kScanOpTypes = {
+    {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
class ScanOp : public XlaOpKernel {
public: