[TF/XLA] Support F64 conversion for tf.cumsum
PiperOrigin-RevId: 312785189 Change-Id: I88b4bfe7c2448218230c09eb11eb672e3a40a85a
This commit is contained in:
parent
431dc17adc
commit
0f178c3708
@ -36,10 +36,8 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
|
constexpr std::array<DataType, 5> kScanOpTypes = {
|
||||||
// the type constraint.
|
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}};
|
||||||
constexpr std::array<DataType, 4> kScanOpTypes = {
|
|
||||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
|
|
||||||
|
|
||||||
class ScanOp : public XlaOpKernel {
|
class ScanOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
|
@ -355,6 +355,15 @@ class DefFunctionTest(test.TestCase):
|
|||||||
self.assertAllClose([5.0, 5.0, 5.0], g())
|
self.assertAllClose([5.0, 5.0, 5.0], g())
|
||||||
self.assertAllClose(compiled_g(), g())
|
self.assertAllClose(compiled_g(), g())
|
||||||
|
|
||||||
|
def testCumsum(self):
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def f(x):
|
||||||
|
return math_ops.cumsum(x)
|
||||||
|
|
||||||
|
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
||||||
|
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user