diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 897db384b7e..17639bd8a75 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse): class CumsumTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) @@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase): class CumprodTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index b5fd7850bfc..7f4fef146f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -39,8 +39,8 @@ namespace { // TODO(phawkins): implement double-sized windowed reductions in XLA and remove // the type constraint. -constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; class ScanOp : public XlaOpKernel { public: