[TF/XLA] Support F64 conversion for tf.cumsum
PiperOrigin-RevId: 312785189 Change-Id: I88b4bfe7c2448218230c09eb11eb672e3a40a85a
This commit is contained in:
parent
431dc17adc
commit
0f178c3708
@ -36,10 +36,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
|
||||
// the type constraint.
|
||||
constexpr std::array<DataType, 4> kScanOpTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
|
||||
constexpr std::array<DataType, 5> kScanOpTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}};
|
||||
|
||||
class ScanOp : public XlaOpKernel {
|
||||
public:
|
||||
|
@ -355,6 +355,15 @@ class DefFunctionTest(test.TestCase):
|
||||
self.assertAllClose([5.0, 5.0, 5.0], g())
|
||||
self.assertAllClose(compiled_g(), g())
|
||||
|
||||
def testCumsum(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(x):
|
||||
return math_ops.cumsum(x)
|
||||
|
||||
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
||||
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user