diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 8431724f438..beb8e7aa174 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -36,10 +36,8 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(phawkins): implement double-sized windowed reductions in XLA and remove -// the type constraint. -constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}}; class ScanOp : public XlaOpKernel { public: diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 0e89887647a..5fdf0487333 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -355,6 +355,15 @@ class DefFunctionTest(test.TestCase): self.assertAllClose([5.0, 5.0, 5.0], g()) self.assertAllClose(compiled_g(), g()) + def testCumsum(self): + + @def_function.function(experimental_compile=True) + def f(x): + return math_ops.cumsum(x) + + f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64) + self.assertAllClose([1.1, 3.3, 6.6], f(f64_input)) + if __name__ == '__main__': ops.enable_eager_execution()