Correctly handle empty matrices in tf.linalg.svd.
PiperOrigin-RevId: 311411299 Change-Id: Ie5440ad4593291409f801fb174fbac3120db0eb7
This commit is contained in:
parent
2046f7c450
commit
18c0da1024
tensorflow
@ -3792,7 +3792,9 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "svd_op",
|
||||
prefix = "svd_op",
|
||||
deps = LINALG_DEPS,
|
||||
deps = LINALG_DEPS + if_cuda([
|
||||
":eye_functor",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/eye_functor.h"
|
||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -390,8 +391,22 @@ class SvdOpGpu : public AsyncOpKernel {
|
||||
done);
|
||||
|
||||
if (n == 0 || m == 0) {
|
||||
// If X is an empty matrix (0 rows, 0 col), X * X' == X.
|
||||
// Therefore, we return X.
|
||||
if (n == m || !compute_uv_ || !full_matrices_) {
|
||||
// S, U, and V are all empty. Nothing to do.
|
||||
done();
|
||||
return;
|
||||
}
|
||||
auto device = context->eigen_device<GPUDevice>();
|
||||
functor::EyeFunctor<GPUDevice, Scalar> eye;
|
||||
if (m > 0) {
|
||||
// Return a full canonical basis for the column space.
|
||||
auto outputU_reshaped = outputU->flat_inner_dims<Scalar, 3>();
|
||||
eye(device, outputU_reshaped);
|
||||
} else if (n > 0) {
|
||||
// Return a full canonical basis for the row space.
|
||||
auto outputV_reshaped = outputV->flat_inner_dims<Scalar, 3>();
|
||||
eye(device, outputV_reshaped);
|
||||
}
|
||||
done();
|
||||
return;
|
||||
}
|
||||
|
@ -83,16 +83,29 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
|
||||
|
||||
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
|
||||
MatrixMaps* outputs) final {
|
||||
int64 n = inputs[0].cols();
|
||||
int64 m = inputs[0].rows();
|
||||
const bool empty = (m == 0 || n == 0);
|
||||
int options = 0; // Don't compute singular vectors;
|
||||
if (compute_uv_) {
|
||||
options = full_matrices_ ? Eigen::ComputeFullU | Eigen::ComputeFullV
|
||||
: Eigen::ComputeThinU | Eigen::ComputeThinV;
|
||||
}
|
||||
Eigen::BDCSVD<Matrix> svd(inputs[0], options);
|
||||
outputs->at(0) = svd.singularValues().template cast<Scalar>();
|
||||
if (compute_uv_) {
|
||||
outputs->at(1) = svd.matrixU();
|
||||
outputs->at(2) = svd.matrixV();
|
||||
if (!empty) {
|
||||
Eigen::BDCSVD<Matrix> svd(inputs[0], options);
|
||||
outputs->at(0) = svd.singularValues().template cast<Scalar>();
|
||||
if (compute_uv_) {
|
||||
outputs->at(1) = svd.matrixU();
|
||||
outputs->at(2) = svd.matrixV();
|
||||
}
|
||||
} else if (compute_uv_ && full_matrices_) {
|
||||
// For an empty matrix where only one dimension is zero, we still set
|
||||
// U or V to the unit matrix for the dimension that is non-zero.
|
||||
if (m > 0) {
|
||||
outputs->at(1) = Matrix::Identity(m, m);
|
||||
} else {
|
||||
outputs->at(2) = Matrix::Identity(n, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -93,7 +93,8 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
|
||||
full_matrices_):
|
||||
|
||||
def CompareSingularValues(self, x, y, tol):
|
||||
self.assertAllClose(x, y, atol=(x[0] + y[0]) * tol)
|
||||
atol = (x[0] + y[0]) * tol if len(x) else tol
|
||||
self.assertAllClose(x, y, atol=atol)
|
||||
|
||||
def CompareSingularVectors(self, x, y, rank, tol):
|
||||
# We only compare the first 'rank' singular vectors since the
|
||||
@ -374,8 +375,8 @@ if __name__ == "__main__":
|
||||
for compute_uv in False, True:
|
||||
for full_matrices in False, True:
|
||||
for dtype in dtypes_to_test:
|
||||
for rows in 1, 2, 5, 10, 32, 100:
|
||||
for cols in 1, 2, 5, 10, 32, 100:
|
||||
for rows in 0, 1, 2, 5, 10, 32, 100:
|
||||
for cols in 0, 1, 2, 5, 10, 32, 100:
|
||||
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
|
||||
shape = batch_dims + (rows, cols)
|
||||
# TF2 does not support placeholders under eager so we skip it
|
||||
|
Loading…
Reference in New Issue
Block a user