[TF:XLA] Implement XlaSvdOp

PiperOrigin-RevId: 236936547
This commit is contained in:
A. Unique TensorFlower 2019-03-05 15:49:07 -08:00 committed by TensorFlower Gardener
parent e26ab7b9e1
commit 9baeb353e1
10 changed files with 249 additions and 22 deletions

View File

@ -247,12 +247,23 @@ tf_xla_py_test(
name = "self_adjoint_eig_op_test",
size = "medium",
srcs = ["self_adjoint_eig_op_test.py"],
# TODO(kuny): remove it after b/124377352 is fixed.
disabled_backends = [
"cpu",
"gpu",
"cpu_ondemand",
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:map_fn",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"@absl_py//absl/testing:parameterized",
],
)
tf_xla_py_test(
name = "svd_op_test",
size = "medium",
srcs = ["svd_op_test.py"],
tags = ["optonly"],
deps = [
":xla_test",

View File

@ -0,0 +1,81 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.svd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test
class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
def _compute_usvt(self, s, u, v):
m = u.shape[-1]
n = v.shape[-1]
if m <= n:
v = v[..., :m]
else:
u = u[..., :n]
return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2))
def _testSvdCorrectness(self, dtype, shape):
np.random.seed(1)
x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
m, n = shape[-2], shape[-1]
_, s_np, _ = np.linalg.svd(x_np)
with self.cached_session() as sess:
x_tf = array_ops.placeholder(dtype)
with self.test_scope():
s, u, v = linalg_ops.svd(x_tf, full_matrices=True)
s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np})
u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m)
v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n)
# Check u_val and v_val are orthogonal matrices.
self.assertLess(np.linalg.norm(u_diff), 1e-2)
self.assertLess(np.linalg.norm(v_diff), 1e-2)
# Check that the singular values are correct, i.e., close to the ones from
# numpy.lingal.svd.
self.assertLess(np.linalg.norm(s_val - s_np), 1e-2)
# The tolerance is set based on our tests on numpy's svd. As our tests
# have batch dimensions and all our operations are on float32, we set the
# tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates
# on double precision.
self.assertLess(
np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2)
SIZES = [1, 2, 5, 10, 32, 64]
DTYPES = [np.float32]
PARAMS = itertools.product(SIZES, DTYPES)
@parameterized.parameters(*PARAMS)
def testSvd(self, n, dtype):
for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
self._testSvdCorrectness(dtype, batch_dims + (n, n))
self._testSvdCorrectness(dtype, batch_dims + (2 * n, n))
self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n))
if __name__ == "__main__":
test.main()

View File

@ -110,6 +110,7 @@ tf_kernel_library(
"xla_reduce_op.cc",
"xla_select_and_scatter_op.cc",
"xla_self_adjoint_eig_op.cc",
"xla_svd_op.cc",
],
hdrs = [
"index_ops.h",
@ -149,7 +150,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:quantize",
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",
"//tensorflow/core:data_flow_ops_op_lib",

View File

@ -35,10 +35,9 @@ class XlaConvOp : public XlaOpKernel {
string precision_config_attr;
OP_REQUIRES_OK(
context, context->GetAttr("precision_config", &precision_config_attr));
OP_REQUIRES(
context,
precision_config_.ParsePartialFromString(precision_config_attr),
errors::InvalidArgument("Error parsing convolution dimension numbers"));
OP_REQUIRES(context,
precision_config_.ParsePartialFromString(precision_config_attr),
errors::InvalidArgument("Error parsing precison config."));
}
void Compile(XlaOpKernelContext* context) override {

View File

@ -0,0 +1,95 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
namespace tensorflow {
namespace {
class XlaSvdOp : public XlaOpKernel {
public:
explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
string precision_config_attr;
OP_REQUIRES_OK(ctx,
ctx->GetAttr("precision_config", &precision_config_attr));
OP_REQUIRES(ctx,
precision_config_.ParsePartialFromString(precision_config_attr),
errors::InvalidArgument("Error parsing precison config."));
if (precision_config_.operand_precision_size() == 0) {
precision_config_.mutable_operand_precision()->Add(
xla::PrecisionConfig::HIGHEST);
}
}
void Compile(XlaOpKernelContext* ctx) override {
auto result = xla::SVD(ctx->Input(0), max_iter_, epsilon_,
precision_config_.operand_precision(0));
ctx->SetOutput(0, result.d);
ctx->SetOutput(1, result.u);
ctx->SetOutput(2, result.v);
}
private:
int32 max_iter_;
float epsilon_;
xla::PrecisionConfig precision_config_;
};
class SvdOp : public XlaOpKernel {
public:
explicit SvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("compute_uv", &compute_uv_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape("input");
int m = input_shape.dim_size(input_shape.dims() - 2);
int n = input_shape.dim_size(input_shape.dims() - 1);
// This is based on heuristics that approx log(n) sweep updates are needed.
// Note: the heuristics provides no theoretical guarantee, max_iter=100 and
// epsilon should be used to determine exit condition.
int max_iter = 2 * tensorflow::Log2Ceiling(std::max(m, n));
auto result = xla::SVD(ctx->Input(0), max_iter, 1e-6);
ctx->SetOutput(0, result.d);
if (compute_uv_) {
int p = std::min(m, n);
if (!full_matrices_) {
if (p < m) {
result.u = xla::SliceInMinorDims(result.u, {0, 0}, {m, p});
}
if (p < n) {
result.v = xla::SliceInMinorDims(result.v, {0, 0}, {n, p});
}
}
ctx->SetOutput(1, result.u);
ctx->SetOutput(2, result.v);
}
}
private:
bool compute_uv_;
bool full_matrices_;
};
REGISTER_XLA_OP(Name("XlaSvd").TypeConstraint("T", kFloatTypes), XlaSvdOp);
REGISTER_XLA_OP(Name("Svd").TypeConstraint("T", kFloatTypes), SvdOp);
} // namespace
} // namespace tensorflow

View File

@ -91,6 +91,40 @@ v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");
REGISTER_OP("XlaSvd")
.Input("a: T")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Attr("precision_config: string")
.Output("s: T")
.Output("u: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).
Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).
a: the input tensor.
max_iter: maximum number of sweep update, i.e., the whole lower triangular
part or upper triangular part based on parameter lower. Heuristically, it has
been argued that approximatly log(min (M, N)) sweeps are needed in practice
(Ref: Golub & van Loan "Matrix Computation").
epsilon: the tolerance ratio.
precision_config: a serialized xla::PrecisionConfig proto.
s: Singular values. The values are sorted in reverse order of magnitude, so
s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
u: Left singular vectors.
v: Right singular vectors.
)doc");
REGISTER_OP("XlaConv")
.Input("lhs: T")
.Input("rhs: T")

View File

@ -295,6 +295,13 @@ def self_adjoint_eig(a, lower, max_iter, epsilon):
return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
def svd(a, max_iter, epsilon, precision_config=None):
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice

View File

@ -750,19 +750,18 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
result.v = Mul(result.v, sign, broadcast_dims);
d = BroadcastInDim(d, dimensions, broadcast_dims);
auto zero = Zero(builder, S32);
// As m >= n, only first m columns vectors are needed to be permuted, and the
// rest of m - n vectors are appended after the sorting is done.
XlaOp sort_u_result =
Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})},
Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
CreateScalarLtComputation(
{shape.element_type(), shape.element_type()}, builder),
num_dims - 1);
// TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed.
XlaOp sort_v_result =
Sort({DynamicSliceInMinorDims(-d, {zero, zero}, {n, n}), result.v},
Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v},
CreateScalarLtComputation(
{shape.element_type(), shape.element_type()}, builder),
num_dims - 1);
@ -779,12 +778,10 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
broadcast_dims);
// Append the rest of m - n vectors.
result.u =
ConcatInDim(builder,
{GetTupleElement(sort_u_result, 1),
DynamicSliceInMinorDims(
result.u, {zero, ScalarLike(zero, n)}, {m, m - n})},
num_dims - 1);
result.u = ConcatInDim(builder,
{GetTupleElement(sort_u_result, 1),
SliceInMinorDims(result.u, {0, n}, {m, m})},
num_dims - 1);
result.u = Mul(
result.u,
Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),

View File

@ -3348,7 +3348,7 @@ cuda_py_test(
"no_rocm", # flaky test
"no_windows",
],
# TODO(kuny): Add xla_enable_strict_auto_jit = True after b/124377352 is fixed.
# b/127344411: xla_enable_strict_auto_jit = True,
)
cuda_py_test(
@ -3385,7 +3385,7 @@ cuda_py_test(
"no_oss", # b/117185141.
"nomsan", # TODO(b/117236102): Re-enable in msan build.
],
xla_enable_strict_auto_jit = True,
# b/127344411: xla_enable_strict_auto_jit = True,
)
cuda_py_test(
@ -3405,7 +3405,7 @@ cuda_py_test(
"no_windows_gpu",
"nomsan",
],
xla_enable_strict_auto_jit = True,
# b/127344411: xla_enable_strict_auto_jit = True,
)
cuda_py_test(

View File

@ -89,7 +89,7 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if ((not is_matrix_norm and ord_ == "fro") or
(is_matrix_norm and is_fancy_p_norm)):
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
if ord_ == "euclidean" or (axis_ is None and len(shape) > 2):
self.skipTest("Not supported by numpy.linalg.norm")
matrix = np.random.randn(*shape_).astype(dtype_)
if dtype_ in (np.complex64, np.complex128):