Correctly handle empty matrices in tf.linalg.svd.

PiperOrigin-RevId: 311411299
Change-Id: Ie5440ad4593291409f801fb174fbac3120db0eb7
This commit is contained in:
A. Unique TensorFlower 2020-05-13 15:00:16 -07:00 committed by TensorFlower Gardener
parent 2046f7c450
commit 18c0da1024
4 changed files with 42 additions and 11 deletions
tensorflow

View File

@ -3792,7 +3792,9 @@ tf_kernel_library(
tf_kernel_library(
name = "svd_op",
prefix = "svd_op",
deps = LINALG_DEPS,
deps = LINALG_DEPS + if_cuda([
":eye_functor",
]),
)
tf_kernel_library(

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
@ -390,8 +391,22 @@ class SvdOpGpu : public AsyncOpKernel {
done);
if (n == 0 || m == 0) {
// If X is an empty matrix (0 rows, 0 col), X * X' == X.
// Therefore, we return X.
if (n == m || !compute_uv_ || !full_matrices_) {
// S, U, and V are all empty. Nothing to do.
done();
return;
}
auto device = context->eigen_device<GPUDevice>();
functor::EyeFunctor<GPUDevice, Scalar> eye;
if (m > 0) {
// Return a full canonical basis for the column space.
auto outputU_reshaped = outputU->flat_inner_dims<Scalar, 3>();
eye(device, outputU_reshaped);
} else if (n > 0) {
// Return a full canonical basis for the row space.
auto outputV_reshaped = outputV->flat_inner_dims<Scalar, 3>();
eye(device, outputV_reshaped);
}
done();
return;
}

View File

@ -83,16 +83,29 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
int64 n = inputs[0].cols();
int64 m = inputs[0].rows();
const bool empty = (m == 0 || n == 0);
int options = 0; // Don't compute singular vectors;
if (compute_uv_) {
options = full_matrices_ ? Eigen::ComputeFullU | Eigen::ComputeFullV
: Eigen::ComputeThinU | Eigen::ComputeThinV;
}
Eigen::BDCSVD<Matrix> svd(inputs[0], options);
outputs->at(0) = svd.singularValues().template cast<Scalar>();
if (compute_uv_) {
outputs->at(1) = svd.matrixU();
outputs->at(2) = svd.matrixV();
if (!empty) {
Eigen::BDCSVD<Matrix> svd(inputs[0], options);
outputs->at(0) = svd.singularValues().template cast<Scalar>();
if (compute_uv_) {
outputs->at(1) = svd.matrixU();
outputs->at(2) = svd.matrixV();
}
} else if (compute_uv_ && full_matrices_) {
// For an empty matrix where only one dimension is zero, we still set
// U or V to the unit matrix for the dimension that is non-zero.
if (m > 0) {
outputs->at(1) = Matrix::Identity(m, m);
} else {
outputs->at(2) = Matrix::Identity(n, n);
}
}
}

View File

@ -93,7 +93,8 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
full_matrices_):
def CompareSingularValues(self, x, y, tol):
self.assertAllClose(x, y, atol=(x[0] + y[0]) * tol)
atol = (x[0] + y[0]) * tol if len(x) else tol
self.assertAllClose(x, y, atol=atol)
def CompareSingularVectors(self, x, y, rank, tol):
# We only compare the first 'rank' singular vectors since the
@ -374,8 +375,8 @@ if __name__ == "__main__":
for compute_uv in False, True:
for full_matrices in False, True:
for dtype in dtypes_to_test:
for rows in 1, 2, 5, 10, 32, 100:
for cols in 1, 2, 5, 10, 32, 100:
for rows in 0, 1, 2, 5, 10, 32, 100:
for cols in 0, 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
# TF2 does not support placeholders under eager so we skip it