[TF/XLA] Support F64 conversion for tf.cumsum

PiperOrigin-RevId: 312785189
Change-Id: I88b4bfe7c2448218230c09eb11eb672e3a40a85a
This commit is contained in:
George Karpenkov 2020-05-21 19:44:32 -07:00 committed by TensorFlower Gardener
parent 431dc17adc
commit 0f178c3708
2 changed files with 11 additions and 4 deletions

View File

@@ -36,10 +36,8 @@ limitations under the License.
namespace tensorflow {
namespace {
// Element types accepted by the XLA scan kernels (e.g. tf.cumsum).
// DT_DOUBLE is included now that XLA supports F64 for these reductions,
// so the old TODO about the double-precision type constraint is resolved.
constexpr std::array<DataType, 5> kScanOpTypes = {
    {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}};
class ScanOp : public XlaOpKernel {
public:

View File

@@ -355,6 +355,15 @@ class DefFunctionTest(test.TestCase):
self.assertAllClose([5.0, 5.0, 5.0], g())
self.assertAllClose(compiled_g(), g())
def testCumsum(self):
# Regression test for F64 scan support: tf.cumsum on a float64 tensor
# must compile under XLA (experimental_compile=True) instead of being
# rejected by the scan-op type constraint.
@def_function.function(experimental_compile=True)
def f(x):
return math_ops.cumsum(x)
# cumsum([1.1, 2.2, 3.3]) -> [1.1, 3.3, 6.6]; assertAllClose tolerates
# the float64 rounding in the running sums.
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
if __name__ == '__main__':
ops.enable_eager_execution()