Fix output shapes in Svd op xla op kernel
TensorFlow shape inference uses zero size vector if full_matrices attribute is false. This is required to use this kernel from MLIR lowering without causing shape mismatch. PiperOrigin-RevId: 343984162 Change-Id: Iadb2db8e1aca2e3711cf03cecc347dd556360427
This commit is contained in:
parent
362fc65f75
commit
93d37e183b
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
@ -75,8 +76,8 @@ class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
no_uv_s_val, no_uv_u_val, no_uv_v_val = sess.run(
|
||||
[no_uv_s, no_uv_u, no_uv_v], feed_dict={x_tf: x_np})
|
||||
self.assertAllClose(no_uv_s_val, s_val, atol=1e-4, rtol=1e-4)
|
||||
self.assertEqual(no_uv_u_val, 0.0)
|
||||
self.assertEqual(no_uv_v_val, 0.0)
|
||||
self.assertEqual(no_uv_u_val.shape, tensor_shape.TensorShape([0]))
|
||||
self.assertEqual(no_uv_v_val.shape, tensor_shape.TensorShape([0]))
|
||||
|
||||
SIZES = [1, 2, 5, 10, 32, 64]
|
||||
DTYPES = [np.float32]
|
||||
|
@ -81,8 +81,10 @@ class SvdOp : public XlaOpKernel {
|
||||
ctx->SetOutput(1, result.u);
|
||||
ctx->SetOutput(2, result.v);
|
||||
} else {
|
||||
ctx->SetOutput(1, xla::ScalarLike(ctx->Input(0), 0.0));
|
||||
ctx->SetOutput(2, xla::ScalarLike(ctx->Input(0), 0.0));
|
||||
auto shape =
|
||||
xla::ShapeUtil::MakeShape(ctx->input_xla_type(0), /*dimensions=*/{0});
|
||||
ctx->SetOutput(1, xla::Zeros(ctx->builder(), shape));
|
||||
ctx->SetOutput(2, xla::Zeros(ctx->builder(), shape));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user