diff --git a/tensorflow/compiler/tests/svd_op_test.py b/tensorflow/compiler/tests/svd_op_test.py index 7e05eeb4c0a..95266dea797 100644 --- a/tensorflow/compiler/tests/svd_op_test.py +++ b/tensorflow/compiler/tests/svd_op_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import linalg_ops @@ -75,8 +76,8 @@ class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase): no_uv_s_val, no_uv_u_val, no_uv_v_val = sess.run( [no_uv_s, no_uv_u, no_uv_v], feed_dict={x_tf: x_np}) self.assertAllClose(no_uv_s_val, s_val, atol=1e-4, rtol=1e-4) - self.assertEqual(no_uv_u_val, 0.0) - self.assertEqual(no_uv_v_val, 0.0) + self.assertEqual(no_uv_u_val.shape, tensor_shape.TensorShape([0])) + self.assertEqual(no_uv_v_val.shape, tensor_shape.TensorShape([0])) SIZES = [1, 2, 5, 10, 32, 64] DTYPES = [np.float32] diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index 8e9ed35783f..b8b542c1b61 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -81,8 +81,10 @@ class SvdOp : public XlaOpKernel { ctx->SetOutput(1, result.u); ctx->SetOutput(2, result.v); } else { - ctx->SetOutput(1, xla::ScalarLike(ctx->Input(0), 0.0)); - ctx->SetOutput(2, xla::ScalarLike(ctx->Input(0), 0.0)); + auto shape = + xla::ShapeUtil::MakeShape(ctx->input_xla_type(0), /*dimensions=*/{0}); + ctx->SetOutput(1, xla::Zeros(ctx->builder(), shape)); + ctx->SetOutput(2, xla::Zeros(ctx->builder(), shape)); } }