[XLA:Python] Add Cholesky, QR, and TriangularSolve to the XLA Python API.
This allows non-TF Python clients to reuse the TensorFlow implementations of these ops (and any future improvements to be shared between users). PiperOrigin-RevId: 225047881
This commit is contained in:
parent
06c60fb179
commit
cf9878d6a6
@ -66,7 +66,10 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:cholesky",
|
||||
"//tensorflow/compiler/xla/client/lib:math",
|
||||
"//tensorflow/compiler/xla/client/lib:qr",
|
||||
"//tensorflow/compiler/xla/client/lib:triangular_solve",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xrt:xrt_proto",
|
||||
|
@ -24,7 +24,10 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/qr.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
@ -865,6 +868,27 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
|
||||
return xla::Sort(keys.op(), {values.op()}, dimension);
|
||||
}
|
||||
|
||||
LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) {
|
||||
return xla::Cholesky(a.op());
|
||||
}
|
||||
|
||||
LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
|
||||
XlaBuilder* builder = a.op().builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
|
||||
return xla::Tuple(builder, {qr.q, qr.r});
|
||||
});
|
||||
}
|
||||
|
||||
LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a,
|
||||
const LocalOp& b,
|
||||
bool left_side, bool lower,
|
||||
bool transpose_a,
|
||||
bool conjugate_a) {
|
||||
return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a,
|
||||
conjugate_a);
|
||||
}
|
||||
|
||||
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
|
||||
const LocalOp& operand) {
|
||||
TF_ASSIGN_OR_RETURN(XlaComputation computation,
|
||||
|
@ -394,6 +394,13 @@ class LocalComputationBuilder {
|
||||
LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
|
||||
int64 dimension);
|
||||
|
||||
LocalOp QR(const LocalOp& a, bool full_matrices);
|
||||
|
||||
LocalOp Cholesky(const LocalOp& a);
|
||||
|
||||
LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side,
|
||||
bool lower, bool transpose_a, bool conjugate_a);
|
||||
|
||||
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
|
||||
|
||||
#define _FORWARD(method_name, return_sig, args_sig) \
|
||||
|
@ -1144,6 +1144,9 @@ tensorflow::ImportNumpy();
|
||||
%unignore xla::swig::LocalComputationBuilder::Imag;
|
||||
%unignore xla::swig::LocalComputationBuilder::Conj;
|
||||
%unignore xla::swig::LocalComputationBuilder::Complex;
|
||||
%unignore xla::swig::LocalComputationBuilder::Cholesky;
|
||||
%unignore xla::swig::LocalComputationBuilder::QR;
|
||||
%unignore xla::swig::LocalComputationBuilder::TriangularSolve;
|
||||
%unignore xla::swig::DeleteLocalComputation;
|
||||
%unignore xla::swig::DestructureLocalShapedBufferTuple;
|
||||
%unignore xla::swig::DestructureXrtAllocationTuple;
|
||||
|
@ -1411,6 +1411,20 @@ class ComputationBuilder(object):
|
||||
"""Enqueues a key-value sort operation onto the computation."""
|
||||
return self._client.SortKeyVal(keys, values, dimension)
|
||||
|
||||
def Cholesky(self, a):
|
||||
"""Enqueues a Cholesky decomposition onto the computation."""
|
||||
return self._client.Cholesky(a)
|
||||
|
||||
def QR(self, a, full_matrices=True):
|
||||
"""Enqueues a QR decomposition onto the computation."""
|
||||
return self._client.QR(a, full_matrices)
|
||||
|
||||
def TriangularSolve(self, a, b, left_side=False, lower=False,
|
||||
transpose_a=False, conjugate_a=False):
|
||||
"""Enqueues a triangular-solve operation onto the computation."""
|
||||
return self._client.TriangularSolve(
|
||||
a, b, left_side, lower, transpose_a, conjugate_a)
|
||||
|
||||
|
||||
def _forward_methods_to_local_builder():
|
||||
"""Forward remaining ComputationBuilder methods to the C API.
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import threading
|
||||
|
||||
@ -51,9 +52,11 @@ class LocalComputationTest(unittest.TestCase):
|
||||
def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
|
||||
self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
|
||||
|
||||
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
|
||||
self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
|
||||
expected)
|
||||
def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7,
|
||||
atol=0):
|
||||
self._ExecuteAndAssertWith(
|
||||
functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
|
||||
c, arguments, expected)
|
||||
|
||||
|
||||
def NumpyArrayF32(*args, **kwargs):
|
||||
@ -1057,6 +1060,38 @@ class SingleOpTest(LocalComputationTest):
|
||||
self.assertTrue(np.all(lo <= result))
|
||||
self.assertTrue(np.all(result < hi))
|
||||
|
||||
def testCholesky(self):
|
||||
l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
|
||||
dtype=np.float32)
|
||||
c = self._NewComputation()
|
||||
c.Cholesky(c.Constant(np.dot(l, l.T)))
|
||||
self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4)
|
||||
|
||||
def testQR(self):
|
||||
a = np.array(
|
||||
[[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
|
||||
dtype=np.float32)
|
||||
c = self._NewComputation()
|
||||
c.QR(c.Constant(a), full_matrices=True)
|
||||
q, r = self._Execute(c, ())
|
||||
np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
|
||||
|
||||
def testTriangularSolve(self):
|
||||
a_vals = np.array(
|
||||
[[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
|
||||
dtype=np.float32)
|
||||
b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
dtype=np.float32)
|
||||
|
||||
c = self._NewComputation()
|
||||
c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False,
|
||||
lower=True, transpose_a=True)
|
||||
self._ExecuteAndCompareClose(c, expected=np.array([
|
||||
[0.5, 0.08333334, 0.04629629, 0.03367003],
|
||||
[2.5, -0.25, -0.1388889, -0.1010101],
|
||||
[4.5, -0.58333331, -0.32407406, -0.23569024],
|
||||
], dtype=np.float32), rtol=1e-4)
|
||||
|
||||
def testIsConstant(self):
|
||||
c = self._NewComputation()
|
||||
a = c.ConstantS32Scalar(3)
|
||||
|
Loading…
Reference in New Issue
Block a user