[TF:XLA] Implement XlaSvdOp
PiperOrigin-RevId: 236936547
This commit is contained in:
parent
e26ab7b9e1
commit
9baeb353e1
@ -247,12 +247,23 @@ tf_xla_py_test(
|
||||
name = "self_adjoint_eig_op_test",
|
||||
size = "medium",
|
||||
srcs = ["self_adjoint_eig_op_test.py"],
|
||||
# TODO(kuny): remove it after b/124377352 is fixed.
|
||||
disabled_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
"cpu_ondemand",
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:map_fn",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "svd_op_test",
|
||||
size = "medium",
|
||||
srcs = ["svd_op_test.py"],
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
81
tensorflow/compiler/tests/svd_op_test.py
Normal file
81
tensorflow/compiler/tests/svd_op_test.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.ops.svd."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
  """Compares `linalg_ops.svd` run under the XLA test scope against NumPy."""

  def _compute_usvt(self, s, u, v):
    """Rebuilds the input matrix as `u * diag(s) * transpose(v)`.

    Args:
      s: array of singular values; its last axis is broadcast against the
        columns of `u`.
      u: array whose innermost two axes hold the left singular vectors.
      v: array whose innermost two axes hold the right singular vectors.

    Returns:
      The batched matrix product `(u * s) @ transpose(v)`.
    """
    m = u.shape[-1]
    n = v.shape[-1]
    # Trim the wider factor to the leading min(m, n) columns so that both
    # factors line up with the entries of `s`.
    if m <= n:
      v = v[..., :m]
    else:
      u = u[..., :n]

    return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2))

  def _testSvdCorrectness(self, dtype, shape):
    """Runs SVD on a random matrix of `shape` and validates the factors.

    Args:
      dtype: NumPy dtype of the random input.
      shape: full input shape; the last two entries are the matrix dimensions
        and any leading entries are batch dimensions.
    """
    # Fixed seed keeps the random input reproducible between runs.
    np.random.seed(1)
    x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
    m, n = shape[-2], shape[-1]
    _, s_np, _ = np.linalg.svd(x_np)
    with self.cached_session() as sess:
      x_tf = array_ops.placeholder(dtype)
      with self.test_scope():
        s, u, v = linalg_ops.svd(x_tf, full_matrices=True)
      s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np})
      # Check u_val and v_val are orthogonal matrices.
      u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m)
      v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n)
      self.assertLess(np.linalg.norm(u_diff), 1e-2)
      self.assertLess(np.linalg.norm(v_diff), 1e-2)
      # Check that the singular values are correct, i.e., close to the ones
      # from numpy.linalg.svd.
      self.assertLess(np.linalg.norm(s_val - s_np), 1e-2)
      # The tolerance is set based on our tests on numpy's svd. As our tests
      # have batch dimensions and all our operations are on float32, we set the
      # tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates
      # on double precision.
      self.assertLess(
          np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2)

  SIZES = [1, 2, 5, 10, 32, 64]
  DTYPES = [np.float32]
  # Materialized as a list: a bare itertools.product object is exhausted after
  # a single pass, which makes the class attribute unusable afterwards.
  PARAMS = list(itertools.product(SIZES, DTYPES))

  @parameterized.parameters(*PARAMS)
  def testSvd(self, n, dtype):
    """Checks square, tall and wide inputs for several batch shapes."""
    batch_shapes = [(), (3,)]
    # The two-level batch is only exercised for the small matrix sizes.
    if n < 10:
      batch_shapes.append((3, 2))
    for batch_dims in batch_shapes:
      self._testSvdCorrectness(dtype, batch_dims + (n, n))
      self._testSvdCorrectness(dtype, batch_dims + (2 * n, n))
      self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -110,6 +110,7 @@ tf_kernel_library(
|
||||
"xla_reduce_op.cc",
|
||||
"xla_select_and_scatter_op.cc",
|
||||
"xla_self_adjoint_eig_op.cc",
|
||||
"xla_svd_op.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"index_ops.h",
|
||||
@ -149,7 +150,9 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/xla/client/lib:qr",
|
||||
"//tensorflow/compiler/xla/client/lib:quantize",
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"//tensorflow/core:bitwise_ops_op_lib",
|
||||
"//tensorflow/core:control_flow_ops_op_lib",
|
||||
"//tensorflow/core:data_flow_ops_op_lib",
|
||||
|
@ -35,10 +35,9 @@ class XlaConvOp : public XlaOpKernel {
|
||||
string precision_config_attr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("precision_config", &precision_config_attr));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
precision_config_.ParsePartialFromString(precision_config_attr),
|
||||
errors::InvalidArgument("Error parsing convolution dimension numbers"));
|
||||
OP_REQUIRES(context,
|
||||
precision_config_.ParsePartialFromString(precision_config_attr),
|
||||
errors::InvalidArgument("Error parsing precison config."));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
|
95
tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc
Normal file
95
tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc
Normal file
@ -0,0 +1,95 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/svd.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaSvdOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
|
||||
string precision_config_attr;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr("precision_config", &precision_config_attr));
|
||||
OP_REQUIRES(ctx,
|
||||
precision_config_.ParsePartialFromString(precision_config_attr),
|
||||
errors::InvalidArgument("Error parsing precison config."));
|
||||
if (precision_config_.operand_precision_size() == 0) {
|
||||
precision_config_.mutable_operand_precision()->Add(
|
||||
xla::PrecisionConfig::HIGHEST);
|
||||
}
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto result = xla::SVD(ctx->Input(0), max_iter_, epsilon_,
|
||||
precision_config_.operand_precision(0));
|
||||
ctx->SetOutput(0, result.d);
|
||||
ctx->SetOutput(1, result.u);
|
||||
ctx->SetOutput(2, result.v);
|
||||
}
|
||||
|
||||
private:
|
||||
int32 max_iter_;
|
||||
float epsilon_;
|
||||
xla::PrecisionConfig precision_config_;
|
||||
};
|
||||
|
||||
class SvdOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit SvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("compute_uv", &compute_uv_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape("input");
|
||||
int m = input_shape.dim_size(input_shape.dims() - 2);
|
||||
int n = input_shape.dim_size(input_shape.dims() - 1);
|
||||
// This is based on heuristics that approx log(n) sweep updates are needed.
|
||||
// Note: the heuristics provides no theoretical guarantee, max_iter=100 and
|
||||
// epsilon should be used to determine exit condition.
|
||||
int max_iter = 2 * tensorflow::Log2Ceiling(std::max(m, n));
|
||||
auto result = xla::SVD(ctx->Input(0), max_iter, 1e-6);
|
||||
ctx->SetOutput(0, result.d);
|
||||
if (compute_uv_) {
|
||||
int p = std::min(m, n);
|
||||
if (!full_matrices_) {
|
||||
if (p < m) {
|
||||
result.u = xla::SliceInMinorDims(result.u, {0, 0}, {m, p});
|
||||
}
|
||||
if (p < n) {
|
||||
result.v = xla::SliceInMinorDims(result.v, {0, 0}, {n, p});
|
||||
}
|
||||
}
|
||||
ctx->SetOutput(1, result.u);
|
||||
ctx->SetOutput(2, result.v);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool compute_uv_;
|
||||
bool full_matrices_;
|
||||
};
|
||||
|
||||
// Registers XlaSvdOp as the XLA kernel for "XlaSvd", with T constrained to
// the types listed in kFloatTypes.
REGISTER_XLA_OP(Name("XlaSvd").TypeConstraint("T", kFloatTypes), XlaSvdOp);
|
||||
// Registers SvdOp as the XLA kernel for "Svd", with T constrained to the
// types listed in kFloatTypes.
REGISTER_XLA_OP(Name("Svd").TypeConstraint("T", kFloatTypes), SvdOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -91,6 +91,40 @@ v: The column v[..., :, i] is the normalized eigenvector corresponding to the
|
||||
eigenvalue w[..., i].
|
||||
)doc");
|
||||
|
||||
// Op definition for the XLA singular value decomposition. The doc string
// below is surfaced as the documentation of the generated op wrappers, so it
// must describe SVD rather than the self-adjoint eigen decomposition.
REGISTER_OP("XlaSvd")
    .Input("a: T")
    .Attr("max_iter: int")
    .Attr("epsilon: float")
    .Attr("precision_config: string")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("T: numbertype")
    .Doc(R"doc(
Computes the singular value decomposition of a batch of matrices
(Note: Only real inputs are supported).

Computes the singular values and singular vectors of the innermost M-by-N
matrices in
tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).

a: the input tensor.

max_iter: maximum number of sweep updates. Heuristically, it has
been argued that approximately log(min (M, N)) sweeps are needed in practice
(Ref: Golub & van Loan "Matrix Computation").

epsilon: the tolerance ratio.

precision_config: a serialized xla::PrecisionConfig proto.

s: Singular values. The values are sorted in reverse order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");
|
||||
|
||||
REGISTER_OP("XlaConv")
|
||||
.Input("lhs: T")
|
||||
.Input("rhs: T")
|
||||
|
@ -295,6 +295,13 @@ def self_adjoint_eig(a, lower, max_iter, epsilon):
|
||||
return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
|
||||
|
||||
|
||||
def svd(a, max_iter, epsilon, precision_config=None):
  """Wraps `gen_xla_ops.xla_svd`, serializing the optional precision config.

  Args:
    a: the input tensor.
    max_iter: passed through as the op's `max_iter` argument.
    epsilon: passed through as the op's `epsilon` argument.
    precision_config: optional object exposing `SerializeToString()`; when it
      is falsy (including the default `None`) an empty string is passed
      instead.

  Returns:
    Whatever `gen_xla_ops.xla_svd` returns for the given arguments.
  """
  serialized_config = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_config)
|
||||
|
||||
|
||||
dynamic_slice = gen_xla_ops.xla_dynamic_slice
|
||||
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
|
||||
|
||||
|
@ -750,19 +750,18 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
||||
result.v = Mul(result.v, sign, broadcast_dims);
|
||||
|
||||
d = BroadcastInDim(d, dimensions, broadcast_dims);
|
||||
auto zero = Zero(builder, S32);
|
||||
|
||||
// As m >= n, only the first n column vectors need to be permuted, and the
|
||||
// rest of m - n vectors are appended after the sorting is done.
|
||||
XlaOp sort_u_result =
|
||||
Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})},
|
||||
Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
|
||||
CreateScalarLtComputation(
|
||||
{shape.element_type(), shape.element_type()}, builder),
|
||||
num_dims - 1);
|
||||
|
||||
// TODO(kuny): use CreateScalarGtComputation after b/124862300 is fixed.
|
||||
XlaOp sort_v_result =
|
||||
Sort({DynamicSliceInMinorDims(-d, {zero, zero}, {n, n}), result.v},
|
||||
Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v},
|
||||
CreateScalarLtComputation(
|
||||
{shape.element_type(), shape.element_type()}, builder),
|
||||
num_dims - 1);
|
||||
@ -779,12 +778,10 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
||||
broadcast_dims);
|
||||
|
||||
// Append the rest of m - n vectors.
|
||||
result.u =
|
||||
ConcatInDim(builder,
|
||||
{GetTupleElement(sort_u_result, 1),
|
||||
DynamicSliceInMinorDims(
|
||||
result.u, {zero, ScalarLike(zero, n)}, {m, m - n})},
|
||||
num_dims - 1);
|
||||
result.u = ConcatInDim(builder,
|
||||
{GetTupleElement(sort_u_result, 1),
|
||||
SliceInMinorDims(result.u, {0, n}, {m, m})},
|
||||
num_dims - 1);
|
||||
result.u = Mul(
|
||||
result.u,
|
||||
Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
|
||||
|
@ -3348,7 +3348,7 @@ cuda_py_test(
|
||||
"no_rocm", # flaky test
|
||||
"no_windows",
|
||||
],
|
||||
# TODO(kuny): Add xla_enable_strict_auto_jit = True after b/124377352 is fixed.
|
||||
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
@ -3385,7 +3385,7 @@ cuda_py_test(
|
||||
"no_oss", # b/117185141.
|
||||
"nomsan", # TODO(b/117236102): Re-enable in msan build.
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
@ -3405,7 +3405,7 @@ cuda_py_test(
|
||||
"no_windows_gpu",
|
||||
"nomsan",
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
@ -89,7 +89,7 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
|
||||
if ((not is_matrix_norm and ord_ == "fro") or
|
||||
(is_matrix_norm and is_fancy_p_norm)):
|
||||
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
|
||||
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
|
||||
if ord_ == "euclidean" or (axis_ is None and len(shape) > 2):
|
||||
self.skipTest("Not supported by numpy.linalg.norm")
|
||||
matrix = np.random.randn(*shape_).astype(dtype_)
|
||||
if dtype_ in (np.complex64, np.complex128):
|
||||
|
Loading…
Reference in New Issue
Block a user