Fix output shapes in Svd op xla op kernel

TensorFlow shape inference uses zero size vector if full_matrices attribute is false. This is required to use this kernel from MLIR lowering without causing shape mismatch.

PiperOrigin-RevId: 343984162
Change-Id: Iadb2db8e1aca2e3711cf03cecc347dd556360427
This commit is contained in:
Smit Hinsu 2020-11-23 21:09:08 -08:00 committed by TensorFlower Gardener
parent 362fc65f75
commit 93d37e183b
2 changed files with 7 additions and 4 deletions

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
@ -75,8 +76,8 @@ class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
no_uv_s_val, no_uv_u_val, no_uv_v_val = sess.run(
[no_uv_s, no_uv_u, no_uv_v], feed_dict={x_tf: x_np})
self.assertAllClose(no_uv_s_val, s_val, atol=1e-4, rtol=1e-4)
self.assertEqual(no_uv_u_val, 0.0)
self.assertEqual(no_uv_v_val, 0.0)
self.assertEqual(no_uv_u_val.shape, tensor_shape.TensorShape([0]))
self.assertEqual(no_uv_v_val.shape, tensor_shape.TensorShape([0]))
SIZES = [1, 2, 5, 10, 32, 64]
DTYPES = [np.float32]

View File

@ -81,8 +81,10 @@ class SvdOp : public XlaOpKernel {
ctx->SetOutput(1, result.u);
ctx->SetOutput(2, result.v);
} else {
ctx->SetOutput(1, xla::ScalarLike(ctx->Input(0), 0.0));
ctx->SetOutput(2, xla::ScalarLike(ctx->Input(0), 0.0));
auto shape =
xla::ShapeUtil::MakeShape(ctx->input_xla_type(0), /*dimensions=*/{0});
ctx->SetOutput(1, xla::Zeros(ctx->builder(), shape));
ctx->SetOutput(2, xla::Zeros(ctx->builder(), shape));
}
}