[XLA:Python] Add Cholesky, QR, and TriangularSolve to the XLA Python API.

This allows non-TF Python clients to reuse the TensorFlow implementations of these ops (and any future improvements to be shared between users).

PiperOrigin-RevId: 225047881
This commit is contained in:
Peter Hawkins 2018-12-11 12:14:29 -08:00 committed by TensorFlower Gardener
parent 06c60fb179
commit cf9878d6a6
6 changed files with 89 additions and 3 deletions

View File

@ -66,7 +66,10 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xrt:xrt_proto",

View File

@ -24,7 +24,10 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
@ -865,6 +868,27 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
return xla::Sort(keys.op(), {values.op()}, dimension);
}
LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) {
return xla::Cholesky(a.op());
}
LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
XlaBuilder* builder = a.op().builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
return xla::Tuple(builder, {qr.q, qr.r});
});
}
LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a,
const LocalOp& b,
bool left_side, bool lower,
bool transpose_a,
bool conjugate_a) {
return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a,
conjugate_a);
}
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
const LocalOp& operand) {
TF_ASSIGN_OR_RETURN(XlaComputation computation,

View File

@ -394,6 +394,13 @@ class LocalComputationBuilder {
LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
int64 dimension);
LocalOp QR(const LocalOp& a, bool full_matrices);
LocalOp Cholesky(const LocalOp& a);
LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a);
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \

View File

@ -1144,6 +1144,9 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Imag;
%unignore xla::swig::LocalComputationBuilder::Conj;
%unignore xla::swig::LocalComputationBuilder::Complex;
%unignore xla::swig::LocalComputationBuilder::Cholesky;
%unignore xla::swig::LocalComputationBuilder::QR;
%unignore xla::swig::LocalComputationBuilder::TriangularSolve;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
%unignore xla::swig::DestructureXrtAllocationTuple;

View File

@ -1411,6 +1411,20 @@ class ComputationBuilder(object):
"""Enqueues a key-value sort operation onto the computation."""
return self._client.SortKeyVal(keys, values, dimension)
def Cholesky(self, a):
"""Enqueues a Cholesky decomposition onto the computation."""
return self._client.Cholesky(a)
def QR(self, a, full_matrices=True):
"""Enqueues a QR decomposition onto the computation."""
return self._client.QR(a, full_matrices)
def TriangularSolve(self, a, b, left_side=False, lower=False,
transpose_a=False, conjugate_a=False):
"""Enqueues a triangular-solve operation onto the computation."""
return self._client.TriangularSolve(
a, b, left_side, lower, transpose_a, conjugate_a)
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import threading
@ -51,9 +52,11 @@ class LocalComputationTest(unittest.TestCase):
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
expected)
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7,
atol=0):
self._ExecuteAndAssertWith(
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
c, arguments, expected)
def NumpyArrayF32(*args, **kwargs):
@ -1057,6 +1060,38 @@ class SingleOpTest(LocalComputationTest):
self.assertTrue(np.all(lo <= result))
self.assertTrue(np.all(result < hi))
def testCholesky(self):
l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
dtype=np.float32)
c = self._NewComputation()
c.Cholesky(c.Constant(np.dot(l, l.T)))
self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4)
def testQR(self):
a = np.array(
[[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
dtype=np.float32)
c = self._NewComputation()
c.QR(c.Constant(a), full_matrices=True)
q, r = self._Execute(c, ())
np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
def testTriangularSolve(self):
a_vals = np.array(
[[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
dtype=np.float32)
b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
dtype=np.float32)
c = self._NewComputation()
c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False,
lower=True, transpose_a=True)
self._ExecuteAndCompareClose(c, expected=np.array([
[0.5, 0.08333334, 0.04629629, 0.03367003],
[2.5, -0.25, -0.1388889, -0.1010101],
[4.5, -0.58333331, -0.32407406, -0.23569024],
], dtype=np.float32), rtol=1e-4)
def testIsConstant(self):
c = self._NewComputation()
a = c.ConstantS32Scalar(3)