Merge branch 'master' into upstream-staging-norocm-tag-1
commit a5eb965967
@@ -26,14 +26,28 @@ import sys as _sys
 
 # API IMPORTS PLACEHOLDER
 
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+# We're using bitwise, but there's nothing special about that.
+_API_MODULE = bitwise  # pylint: disable=undefined-variable
+_current_module = _sys.modules[__name__]
+_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
+if not hasattr(_current_module, '__path__'):
+  __path__ = [_tf_api_dir]
+elif _tf_api_dir not in __path__:
+  __path__.append(_tf_api_dir)
+
 # pylint: disable=g-bad-import-order
 from tensorflow.python.tools import component_api_helper as _component_api_helper
+_component_api_helper.package_hook(
+    parent_package_str=__name__,
+    child_package_str=('tensorboard.summary._tf.summary'),
+    error_msg="Limited tf.summary API due to missing TensorBoard installation")
 _component_api_helper.package_hook(
     parent_package_str=__name__,
     child_package_str=(
        'tensorflow_estimator.python.estimator.api._v2.estimator'))
-_current_module = _sys.modules[__name__]
 if not hasattr(_current_module, 'estimator'):
   _component_api_helper.package_hook(
       parent_package_str=__name__,
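Note: the added block works because a package's __path__ lists the directories searched for its submodules; appending the generated-API directory makes "from tensorflow.foo import bar" resolve there too. A minimal standalone sketch of the mechanism (the `pkg` module and `pkg_parts` directory are hypothetical, not from this diff):

    import os
    import sys
    import types

    # Sketch only: a package's __path__ lists the directories searched for
    # its submodules. Appending a directory makes modules that physically
    # live elsewhere importable under the package.
    pkg = types.ModuleType("pkg")
    pkg.__path__ = []              # having __path__ marks it as a package
    sys.modules["pkg"] = pkg

    extra_dir = os.path.join(os.getcwd(), "pkg_parts")  # hypothetical dir
    if extra_dir not in pkg.__path__:
        pkg.__path__.append(extra_dir)
    # If pkg_parts/foo.py exists, "import pkg.foo" now succeeds.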
@@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'):
   _component_api_helper.package_hook(
       parent_package_str=__name__,
       child_package_str=('tensorflow.python.keras.api._v2.keras'))
-# Make sure directory containing top level submodules is in
-# the __path__ so that "from tensorflow.foo import bar" works.
-# We're using bitwise, but there's nothing special about that.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__))  # pylint: disable=undefined-variable
-if not hasattr(_current_module, '__path__'):
-  __path__ = [_tf_api_dir]
-elif _tf_api_dir not in __path__:
-  __path__.append(_tf_api_dir)
 
 # Enable TF2 behaviors
 from tensorflow.python.compat import v2_compat as _compat  # pylint: disable=g-import-not-at-top
@@ -70,7 +70,7 @@ _API_MODULE = app  # pylint: disable=undefined-variable
 
 # Make sure directory containing top level submodules is in
 # the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))  # pylint: disable=undefined-variable
+_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
 if not hasattr(_current_module, '__path__'):
   __path__ = [_tf_api_dir]
 elif _tf_api_dir not in __path__:
@@ -150,6 +150,7 @@ cc_library_with_android_deps(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
     ],
 )
@@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc(
     pkg = "//tensorflow/core",
 )
 
+tf_gen_op_wrappers_cc(
+    name = "tpu_ops",
+    include_internal_ops = 1,
+    op_lib_names = [
+        "tpu_configuration_ops",
+        "tpu_cross_replica_ops",
+        "tpu_embedding_ops",
+        "tpu_functional_ops",
+        "tpu_heartbeat_ops",
+        "tpu_host_compute_ops",
+        "tpu_infeed_ops",
+        "tpu_outfeed_ops",
+        "tpu_ordinal_selector_ops",
+        "tpu_replication_ops",
+    ],
+    pkg = "//tensorflow/core",
+    visibility = ["//tensorflow:internal"],
+)
+
 cc_library_with_android_deps(
     name = "cc_op_gen_main",
     srcs = [
@@ -81,6 +81,7 @@ cc_library(
     ] + if_not_mobile([
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:lib",
+        "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
     ]) + if_android([
@@ -22,11 +22,16 @@ import os as _os
 import sys as _sys
 
 # pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 
 # API IMPORTS PLACEHOLDER
 
 from tensorflow.python.tools import component_api_helper as _component_api_helper
+_component_api_helper.package_hook(
+    parent_package_str=__name__,
+    child_package_str=('tensorboard.summary._tf.summary'),
+    error_msg=(
+        "Limited tf.compat.v2.summary API due to missing TensorBoard "
+        "installation"))
 _component_api_helper.package_hook(
     parent_package_str=__name__,
     child_package_str=(
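Note: as the error_msg argument shows, package_hook grafts an optional child package into the parent namespace and fails softly when the child is not installed. A rough sketch of that idea (simplified and assumed behavior, not the actual component_api_helper implementation):

    import importlib
    import sys

    def package_hook(parent_package_str, child_package_str, error_msg=None):
        """Sketch of the soft-failure hook pattern visible in the diff."""
        try:
            child = importlib.import_module(child_package_str)
        except ImportError:
            if error_msg:
                print(error_msg)   # degrade gracefully, keep a limited API
            return
        parent = sys.modules[parent_package_str]
        # Expose the child under its last name component, e.g. "summary".
        setattr(parent, child_package_str.rsplit(".", 1)[-1], child)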
@@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute")
     .Attr("shapes: list(shape) >= 0")
     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
-REGISTER_OP("_XlaSendFromHost")
-    .Input("inputs: Tinputs")
-    .Input("dynamic_key: string")
-    .Attr("Tinputs: list(type) >= 0")
-    .Attr("key: string")
-    .Attr("device_ordinal: int")
-    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
-
-REGISTER_OP("_XlaRecvAtHost")
-    .Input("dynamic_key: string")
-    .Output("outputs: Toutputs")
-    .Attr("Toutputs: list(type) >= 0")
-    .Attr("key: string")
-    .Attr("device_ordinal: int")
-    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
-
 REGISTER_OP("InputTest")
     .Output("o: float")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
@@ -200,7 +200,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
   auto serialized = absl::make_unique<char[]>(size);
   TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size));
   uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
-  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+  VLOG(1) << "Subgraph fingerprint:" << fingerprint;
   call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
   return Status::OK();
 }
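Note: the surrounding code derives a stable op name by deterministically serializing the subgraph and fingerprinting the bytes; the change only demotes the log line from INFO to VLOG(1). The naming scheme in Python (a sketch; sha256 is only a stand-in for the 64-bit Fingerprint64 used in the diff):

    import hashlib

    def fingerprinted_op_name(base_name, serialized_graph):
        # Deterministic serialization -> stable hash -> stable op name.
        digest = hashlib.sha256(serialized_graph).hexdigest()[:16]
        return "{}_{}".format(base_name, digest)

    print(fingerprinted_op_name("cluster", b"\x08\x01\x12\x03abc"))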
@@ -250,6 +250,29 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "self_adjoint_eig_op_test",
+    size = "medium",
+    srcs = ["self_adjoint_eig_op_test.py"],
+    # TODO(kuny): remove it after b/124377352 is fixed.
+    disabled_backends = [
+        "cpu",
+        "gpu",
+        "cpu_ondemand",
+    ],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:map_fn",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tf_xla_py_test(
     name = "matrix_triangular_solve_op_test",
     size = "small",
tensorflow/compiler/tests/self_adjoint_eig_op_test.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.self_adjoint_eig."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.platform import test
+
+
+class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase):
+
+  def _test(self, dtype, shape):
+    np.random.seed(1)
+    x_np = np.random.uniform(
+        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+    x_np = x_np + np.swapaxes(x_np, -1, -2)
+    n = shape[-1]
+
+    e_np, _ = np.linalg.eigh(x_np)
+    with self.cached_session() as sess:
+      x_tf = array_ops.placeholder(dtype)
+      with self.test_scope():
+        e, v = linalg_ops.self_adjoint_eig(x_tf)
+      e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np})
+
+      v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n)
+      self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6)
+      self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6)
+
+  SIZES = [1, 2, 5, 10, 32]
+  DTYPES = [np.float32]
+  PARAMS = itertools.product(SIZES, DTYPES)
+
+  @parameterized.parameters(*PARAMS)
+  def testSelfAdjointEig(self, n, dtype):
+    for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
+      self._test(dtype, batch_dims + (n, n))
+
+
+if __name__ == "__main__":
+  test.main()
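Note: the new test checks the two defining properties of a self-adjoint eigendecomposition: the eigenvector matrix is orthogonal (V Vᵀ ≈ I) and the eigenvalues match NumPy's eigh. The same check stripped to plain NumPy (a sketch, not part of the diff):

    import numpy as np

    np.random.seed(1)
    n = 5
    x = np.random.uniform(-1.0, 1.0, size=(n, n)).astype(np.float32)
    x = x + x.T                          # symmetrize, as the test does

    e, v = np.linalg.eigh(x)             # reference decomposition
    assert np.mean((v @ v.T - np.eye(n)) ** 2) < 1e-6       # orthogonality
    assert np.allclose(v @ np.diag(e) @ v.T, x, atol=1e-4)  # reconstruction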
@@ -102,7 +102,7 @@ class ListOpsTest(xla_test.XLATestCase):
       _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    "Set the max number of elements"):
-        self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15)))
+        self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
 
   def testEmptyTensorListMax(self):
     with self.cached_session() as sess, self.test_scope():
|
|||||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
|
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
|
||||||
self.assertAllEqual(t, [3.0, 2.0])
|
self.assertAllEqual(t, [3.0, 2.0])
|
||||||
|
|
||||||
|
def testSetDoesNotUpdatePushIndex(self):
|
||||||
|
with self.cached_session(), self.test_scope():
|
||||||
|
l = list_ops.empty_tensor_list(
|
||||||
|
element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
|
||||||
|
# SetItem should not change the push index.
|
||||||
|
l = list_ops.tensor_list_set_item(l, 1, 3.)
|
||||||
|
l = list_ops.tensor_list_push_back(l, 5.)
|
||||||
|
l = list_ops.tensor_list_push_back(l, 7.)
|
||||||
|
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
|
||||||
|
self.assertAllEqual(t, [5., 7.])
|
||||||
|
|
||||||
def testGetSetReserved(self):
|
def testGetSetReserved(self):
|
||||||
with self.cached_session(), self.test_scope():
|
with self.cached_session(), self.test_scope():
|
||||||
l = list_ops.tensor_list_reserve(
|
l = list_ops.tensor_list_reserve(
|
||||||
@@ -146,6 +157,25 @@ class ListOpsTest(xla_test.XLATestCase):
       t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
       self.assertAllEqual(t, [3.0, 0.0])
 
+  def testSetStackReservedUnknownElementShape(self):
+    with self.cached_session(), self.test_scope():
+      l = list_ops.tensor_list_reserve(
+          element_dtype=dtypes.float32, element_shape=None, num_elements=2)
+      l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0])
+      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+      self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]])
+
+  def testPushInEmptyListWithUnknownElementShape(self):
+    with self.cached_session(), self.test_scope():
+      l = list_ops.empty_tensor_list(
+          element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
+      l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
+      # Pushing an element with a different shape should raise an error.
+      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
+        l = list_ops.tensor_list_push_back(l, 5.)
+        self.evaluate(
+            list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
+
   def testGetSetReservedNonScalar(self):
     with self.cached_session() as sess, self.test_scope():
       l = list_ops.tensor_list_reserve(
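Note: these tests pin down the semantics of lists with an unknown element shape: the first write fixes the shape, and later writes must match it. A toy plain-Python model of that contract (a sketch, not the XLA implementation):

    import numpy as np

    class SketchTensorList:
        """Toy model of the lazily-shaped list the tests exercise."""

        def __init__(self, num_elements):
            self.storage = None            # allocated on first write
            self.num_elements = num_elements

        def set_item(self, index, value):
            value = np.asarray(value, dtype=np.float32)
            if self.storage is None:       # first write fixes element shape
                self.storage = np.zeros(
                    (self.num_elements,) + value.shape, dtype=np.float32)
            if value.shape != self.storage.shape[1:]:
                raise ValueError("Shape mismatch: %s vs %s"
                                 % (value.shape, self.storage.shape[1:]))
            self.storage[index] = value
            return self

        def stack(self):
            return self.storage

    l = SketchTensorList(num_elements=2).set_item(0, [3.0, 4.0])
    print(l.stack())   # [[3. 4.] [0. 0.]] -- matches the reserved-list test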
@@ -72,6 +72,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
         output = op(pinp)
         result = session.run(output, {pinp: inp})
         if equality_test is None:
+          self.assertEqual(output.dtype, expected.dtype)
           self.assertAllCloseAccordingToType(
               result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
         else:
@@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
       self._assertOpOutputMatchesExpected(
           math_ops.log1p,
           np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
-          expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)),
+          expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
+                                     dtype=dtype)).astype(dtype),
           rtol=1e-4,
           atol=1e-6)
 
@@ -710,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
       self._assertOpOutputMatchesExpected(
           math_ops.abs,
           np.array([[2, -1]], dtype=dtype),
-          expected=np.array([[2, 1]], dtype=dtype))
+          expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
 
       self._assertOpOutputMatchesExpected(
           math_ops.negative,
@@ -880,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
           np.array([[-1], [1], [4]], dtype=dtype),
           expected=np.int32(3))
 
+  def testSizeWithInt64OutType(self):
+
+    def size_op(x):
+      return array_ops.size_internal(x, optimize=False, out_type=np.int64)
+
+    for dtype in self.numeric_types:
+      self._assertOpOutputMatchesExpected(
+          size_op,
+          np.array([[-1], [1], [4]], dtype=dtype),
+          expected=np.int64(3))
+
   def testUnpack(self):
     self._assertOpOutputMatchesExpected(
         array_ops.unstack,
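Note: this test, together with the strengthened dtype assertion above and the SizeOp kernel change later in this diff (ConvertElementType to ctx->output_xla_type(0)), verifies that the size op honors out_type and emits int64 when asked. The expected behavior in NumPy terms (illustrative only):

    import numpy as np

    x = np.array([[-1], [1], [4]])
    size_i32 = np.int32(x.size)   # default int32 out_type behaves like this
    size_i64 = np.int64(x.size)   # out_type=int64 must yield an int64 scalar

    assert size_i32 == size_i64 == 3
    assert size_i64.dtype == np.int64   # the dtype itself is checked, not just the value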
@@ -989,7 +1002,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
   def _assertSoftplusMatchesExpected(self, features, dtype):
     features = np.array(features, dtype=dtype)
     zero = np.asarray(0).astype(dtype)
-    expected = np.logaddexp(zero, features)
+    expected = np.logaddexp(zero, features).astype(dtype)
     self._assertOpOutputMatchesExpected(
         nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
 
@@ -171,13 +171,11 @@ tf_cuda_library(
     name = "trt_resources",
     srcs = [
         "utils/trt_int8_calibrator.cc",
-        "utils/trt_resource_manager.cc",
        "utils/trt_resources.cc",
     ],
     hdrs = [
         "utils/trt_int8_calibrator.h",
         "utils/trt_lru_cache.h",
-        "utils/trt_resource_manager.h",
         "utils/trt_resources.h",
     ],
     deps = [
@@ -266,7 +264,6 @@ tf_cuda_library(
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:graph",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:devices",
@@ -362,11 +359,12 @@ cc_library(
     ],
 )
 
-tf_cc_test(
+tf_cuda_cc_test(
     name = "segment_test",
     size = "small",
     srcs = ["segment/segment_test.cc"],
     tags = [
+        "no_cuda_on_cpu_tap",
         "no_windows",
         "nomac",
     ],
@@ -432,7 +430,7 @@ cc_library(
     copts = tf_copts(),
     deps = [
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_proto_parsing",
     ],
 )
 
@@ -441,7 +439,7 @@ cc_library(
     srcs = ["utils/test_utils.cc"],
     hdrs = ["utils/test_utils.h"],
     deps = [
-        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_proto_parsing",
         "@com_googlesource_code_re2//:re2",
     ],
 )
@@ -30,7 +30,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
@@ -106,6 +105,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
       "ExpandDims",
       "FusedBatchNorm",
       "FusedBatchNormV2",
+      "GatherV2",
       "Identity",
       "LeakyRelu",
       "Log",
@@ -190,55 +190,6 @@ tensorflow::Status BuildNodeMap(
 
 }  // namespace
 
-// Function to get calibration from ResourceMgr and put them into nodedef.
-tensorflow::Status ConvertCalibGraphToInferGraph(
-    const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph,
-    bool is_dyn_op) {
-  LOG(INFO) << "Starting Calib Conversion";
-  *infer_graph = graph_def;
-  auto trt_rm = TRTResourceManager::instance();
-  auto calib_rm = trt_rm->getManager("TRTCalibration");
-  int num_nodes = infer_graph->node_size();
-  if (!is_dyn_op) {
-    LOG(WARNING) << "Construction of static int8 engine is not implemented "
-                    "yet!. Dynamic engine will be constructed";
-  }
-  for (int i = 0; i < num_nodes; ++i) {
-    auto n = infer_graph->mutable_node(i);
-    if (n->op() == "TRTEngineOp") {
-      VLOG(1) << "Processing " << n->name();
-      const string& container_name = n->attr().at("segment_funcdef_name").s();
-      TRTCalibrationResource* cres = nullptr;
-      auto status = calib_rm->Lookup(container_name, "Calibrator", &cres);
-      if (!status.ok()) {
-        LOG(ERROR) << "Could not get Calibration information. Did you run with "
-                      "calibration data?";
-        return tensorflow::errors::FailedPrecondition(
-            "Need to run graph with calibration data first!");
-      }
-      tensorflow::core::ScopedUnref calib_sc(cres);
-      if (cres->calibrator_) {
-        cres->calibrator_->waitAndSetDone();
-        cres->thr_->join();
-        const auto& calibration_table =
-            cres->calibrator_->getCalibrationTableAsString();
-        if (calibration_table.empty()) {
-          LOG(ERROR) << "Calibration table is empty";
-          return tensorflow::errors::Unknown(
-              "Calibration table is missing. This shouldn't have happened!");
-        }
-        n->mutable_attr()->at("calibration_data").set_s(calibration_table);
-      } else {
-        LOG(ERROR) << "Can't get TRTCalibrator from resource manager!";
-        return tensorflow::errors::Unknown(
-            "Can't get TRTCalibrator from resource manager!");
-      }
-      TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name));
-    }
-  }
-  return tensorflow::Status::OK();
-}
-
 tensorflow::Status ConvertGraphDefToTensorRT(
     const tensorflow::GraphDef& graph_def,
     const std::vector<string>& output_names, size_t max_batch_size,
@@ -85,12 +85,6 @@ struct ConversionParams {
   std::vector<int> cached_engine_batches;  // list of cached engines
 };
 
-// This method extracts calibration information from the resource managers
-// and puts them in to engine nodedefs.
-tensorflow::Status ConvertCalibGraphToInferGraph(
-    const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
-    bool is_dyn_op);
-
 // - max_batch_size: maximum batch size which can be used for inference for
 //   optimization targets inference run with max batch size.
 // - max_workspace_size_bytes: The upper bound of memory allowance for engine
@@ -30,7 +30,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -379,6 +378,32 @@ tensorflow::Status CreateBroadcastableScalarConstant(
   return Status::OK();
 }
 
+// Convert an axis from TF format to TRT format while validating. TF format
+// includes the batch dimension, while TRT does not. TF can also use negative
+// indices.
+// TODO(tmorris): Use this method in more ops.
+tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims,
+                               absl::string_view node_name, int* trt_axis) {
+  const int tf_nb_dims = trt_nb_dims + 1;
+  // Check bounds.
+  if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
+    return tensorflow::errors::InvalidArgument(
+        "Axis value of ", tf_axis, " is out of bounds, must be in range [",
+        -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name);
+  }
+  // Make negative axis positive.
+  if (tf_axis < 0) tf_axis += tf_nb_dims;
+  // Don't allow axis to be the batch dimension.
+  if (tf_axis == 0) {
+    return tensorflow::errors::Unimplemented(
+        "TensorRT does not allow manipulation of the batch dimension, at ",
+        node_name);
+  }
+  // Remove batch dimension.
+  *trt_axis = tf_axis - 1;
+  return Status::OK();
+}
+
 inline bool DimsEqual(const nvinfer1::Dims& dim_l,
                       const nvinfer1::Dims& dim_r) {
   if (dim_l.nbDims != dim_r.nbDims) {
|
|||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status ConvertGather(OpConverterParams* params) {
|
||||||
|
const auto& inputs = params->inputs;
|
||||||
|
const auto& node_def = params->node_def;
|
||||||
|
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||||
|
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
|
||||||
|
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
|
||||||
|
if (axis.size() != 1) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"Axis for GatherV2 must be a scalar, at ", node_def.name());
|
||||||
|
}
|
||||||
|
int trt_axis = 0;
|
||||||
|
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
|
||||||
|
node_def.name(), &trt_axis));
|
||||||
|
if (params->validation_only) return Status::OK();
|
||||||
|
|
||||||
|
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
|
||||||
|
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
|
||||||
|
*const_cast<nvinfer1::ITensor*>(inputs.at(1).tensor()), trt_axis);
|
||||||
|
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||||
|
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::Status ConvertMatMulHelper(OpConverterParams* params,
|
tensorflow::Status ConvertMatMulHelper(OpConverterParams* params,
|
||||||
TRT_TensorOrWeights tensor_input,
|
TRT_TensorOrWeights tensor_input,
|
||||||
TRT_ShapedWeights weights_raw,
|
TRT_ShapedWeights weights_raw,
|
||||||
@@ -3643,6 +3691,7 @@ static void RegisterValidatableOpConverters(
   (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
   (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
   (*registration)["ExpandDims"] = ConvertExpandDims;
+  (*registration)["GatherV2"] = ConvertGather;
   (*registration)["LeakyRelu"] = ConvertLeakyRelu;
   (*registration)["MatMul"] = ConvertMatMul;
   (*registration)["Pad"] = ConvertPad;
@@ -190,6 +190,11 @@ class TRT_ShapedWeights {
 
   string DebugString() const;
 
+  template <typename T>
+  absl::Span<const T> GetSpan() const {
+    return absl::Span<const T>(tensor_.flat<T>().data(), count());
+  }
+
   // TODO(aaroey): make these private.
   nvinfer1::Dims shape_;  // Note: shape.type[] is not used.
   tensorflow::DataType type_;
@@ -3129,6 +3129,126 @@ TEST_F(OpConverterTest, ConvertTopK) {
   }
 }
 
+template <DataType dtype>
+void TestConvertGather(OpConverterTest* test) {
+  typedef typename EnumToDataType<dtype>::Type CType;
+
+  // Get the NodeDef for GatherV2.
+  Scope s = Scope::NewRootScope();
+  auto params = ops::Placeholder(s.WithOpName("params"), dtype);
+  auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
+  auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
+  auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
+  const NodeDef& node_def = gather.operation.node()->def();
+
+  struct TestParams {
+    std::vector<int> params_dims;
+    std::vector<int> indices_dims;
+    std::vector<int> indices;
+    int axis;
+    std::vector<int> expected_output_dims;
+    std::vector<int> expected_output;
+  };
+
+  // Input is the same {1, 2, 3, 4, 5, 6} for all cases.
+  const int kGatherOKCases = 5;
+  TestParams ok_params[kGatherOKCases] = {
+      // Vector indices (output is rank(params)).
+      TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}},
+      TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}},
+      TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}},
+      TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}},
+      // Higher rank indices (output is rank(params) + rank(indices) - 1).
+      TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}},
+  };
+
+  // Ok.
+  for (int i = 0; i < kGatherOKCases; i++) {
+    test->Reset();
+    test->AddTestTensor("params", ok_params[i].params_dims, 1,
+                        TfDataTypeToTrt(dtype));
+    test->AddTestTensor("indices", ok_params[i].indices_dims, 1,
+                        nvinfer1::DataType::kINT32);
+    test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
+    test->RunValidationAndConversion(node_def);
+    TRT_TensorOrWeights output;
+    TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output));
+    EXPECT_TRUE(output.is_tensor());
+    ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
+                             output.tensor()->getDimensions());
+
+    // Create input in CType and convert expected output to CType.
+    std::vector<CType> inputs = {CType(1), CType(2), CType(3),
+                                 CType(4), CType(5), CType(6)};
+    std::vector<CType> converted_expected_output(
+        ok_params[i].expected_output.begin(),
+        ok_params[i].expected_output.end());
+
+    const DataVec input_data{
+        {"params", test::AsTensor<CType>(inputs)},
+        {"indices", test::AsTensor<int32>(ok_params[i].indices)}};
+    DataVec output_data{
+        {"my_gather",
+         ConstructTensor<CType>(ok_params[i].expected_output.size())}};
+    test->BuildAndRun(input_data, &output_data);
+    EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+                ElementsAreArray(converted_expected_output));
+  }
+}
+
+TEST_F(OpConverterTest, ConvertGather) {
+  {
+    // Input list is empty, should fail.
+    NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {});
+    RunValidationAndConversion(
+        node_def, error::INVALID_ARGUMENT,
+        "GatherV2 got 0 inputs but expected 3, at my_gather");
+  }
+
+  // Get the NodeDef for GatherV2.
+  Scope s = Scope::NewRootScope();
+  auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT);
+  auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
+  auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
+  auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
+  const NodeDef& node_def = gather.operation.node()->def();
+  {
+    // Axis is a tensor, should fail.
+    Reset();
+    AddTestTensor("params", {1, 2, 3});
+    AddTestTensor("indices", {2});
+    AddTestTensor("axis", {1});
+    RunValidationAndConversion(
+        node_def, error::UNIMPLEMENTED,
+        "The input \"axis\" for GatherV2 must be a constant, at my_gather");
+  }
+  {
+    // Axis is out of bounds, should fail.
+    Reset();
+    AddTestTensor("params", {1, 2, 3});
+    AddTestTensor("indices", {2});
+    AddTestWeights<int32>("axis", {1}, {4});
+    RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
+                               "Axis value of 4 is out of bounds, must be in "
+                               "range [-4, 4), at my_gather");
+  }
+  {
+    // Axis is batch dimension, should fail.
+    Reset();
+    AddTestTensor("params", {1, 2, 3});
+    AddTestTensor("indices", {2});
+    AddTestWeights<int32>("axis", {1}, {0});
+    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+                               "TensorRT does not allow manipulation of the "
+                               "batch dimension, at my_gather");
+  }
+
+  Reset();
+  TestConvertGather<DT_FLOAT>(this);
+  TestConvertGather<DT_HALF>(this);
+  TestConvertGather<DT_INT32>(this);
+}
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
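Note: the OK cases gather along TF axis 3 (or -1) of a batch-1 tensor holding {1..6} with dims {1, 2, 3}. NumPy reproduces the expected outputs (a sketch, not the TRT path):

    import numpy as np

    # Batch size 1 plus the {1, 2, 3} dims from TestParams, values 1..6.
    params = np.arange(1, 7).reshape(1, 1, 2, 3)

    out = np.take(params, [0], axis=3)        # TF axis 3 == last dim
    assert out.shape == (1, 1, 2, 1)          # batch + expected {1, 2, 1}
    assert out.ravel().tolist() == [1, 4]     # matches expected_output {1, 4}

    out = np.take(params, [2, 0, 1], axis=3)
    assert out.ravel().tolist() == [3, 1, 2, 6, 4, 5]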
@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -295,27 +294,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
             return this->AllocateCalibrationResources(ctx, cr);
           }}));
   tensorflow::core::ScopedUnref calib_sc(calib_res);
-  // TODO(aaroey): here we also add the resource to the ResourceMgr singleton.
-  // This is needed before we migrate all uses of calib_graph_to_infer_graph()
-  // to the new calibration workflow. After that we'll remove this block.
-  {
-    auto deprecated_rm =
-        TRTResourceManager::instance()->getManager("TRTCalibration");
-    TRTCalibrationResource* copied_resource = nullptr;
-    // Check whether the resource exists, and create it if not.
-    if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource)
-            .ok()) {
-      // Do nothing if the resource exists.
-      copied_resource->Unref();
-    } else {
-      copied_resource = calib_res;
-      // Increase the refcount by 1 then transfer the ownership of that refcount
-      // to the ResourceMgr singleton.
-      copied_resource->Ref();
-      OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator",
-                                                copied_resource));
-    }
-  }
   int num_inputs = ctx->num_inputs();
   // Pass input data to calibrator
   std::unordered_map<string, void*> input_data;
@@ -38,16 +38,23 @@ def load_trt_ops():
   if _trt_ops_so:
     return
 
+  try:
+    # pylint: disable=g-import-not-at-top,unused-variable
+    # This registers the TRT ops, it doesn't require loading TRT library.
+    from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
+    # pylint: enable=g-import-not-at-top,unused-variable
+  except ImportError as e:
+    print("**** Failed to import TF-TRT ops. This is because the binary was "
+          "not built with CUDA or TensorRT enabled. ****")
+    raise e
+
   # TODO(laigd): we should load TF-TRT kernels here as well after removing the
   # swig binding.
   try:
-    # TODO(lagid): It is not known why these unused imports were introduced.
-    # Investigate and get rid of these, if not required.
-    # pylint: disable=unused-import,g-import-not-at-top,unused-variable
-    from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
+    # pylint: disable=g-import-not-at-top
     from tensorflow.python.framework import load_library
     from tensorflow.python.platform import resource_loader
-    # pylint: enable=unused-import,g-import-not-at-top,unused-variable
+    # pylint: enable=g-import-not-at-top
 
     _trt_ops_so = load_library.load_op_library(
         resource_loader.get_path_to_datafile("_trt_ops.so"))
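Note: the loader now separates op registration (a pure Python import that works without the TRT library) from loading the shared object, and turns a missing build into an explicit ImportError. The guard pattern in isolation (a sketch; `some_gen_ops` is a hypothetical module name):

    def load_optional_ops():
        """Sketch of the guarded-import pattern; some_gen_ops is hypothetical."""
        try:
            # Importing the generated module registers the ops as a side
            # effect; it does not require the native library to be present.
            import some_gen_ops  # hypothetical generated-op module
        except ImportError:
            print("**** Optional ops unavailable: binary built without them. ****")
            raise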
@@ -19,7 +19,9 @@ limitations under the License.
 #include <vector>
 
 #include "re2/re2.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace tensorrt {
@@ -16,8 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
 
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace tensorrt {
tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc (deleted file)
@@ -1,45 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace tensorrt {
-
-std::shared_ptr<TRTResourceManager>
-tensorflow::tensorrt::TRTResourceManager::instance() {
-  static std::shared_ptr<TRTResourceManager> instance_(new TRTResourceManager);
-  return instance_;
-}
-
-std::shared_ptr<tensorflow::ResourceMgr>
-tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) {
-  // mutex is held for lookup only. Most instantiations where mutex will be held
-  // longer will be during op creation and should be ok.
-  tensorflow::mutex_lock lock(map_mutex_);
-  auto s = managers_.find(op_name);
-  if (s == managers_.end()) {
-    auto it = managers_.emplace(
-        op_name, std::make_shared<tensorflow::ResourceMgr>(op_name));
-    VLOG(1) << "Returning a new manager " << op_name;
-    return it.first->second;
-  }
-  VLOG(1) << "Returning old manager " << op_name;
-  return s->second;
-}
-
-}  // namespace tensorrt
-}  // namespace tensorflow
tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h (deleted file)
@@ -1,45 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_
-#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_
-#include <memory>
-
-#include <string>
-#include <unordered_map>
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/platform/mutex.h"
-
-namespace tensorflow {
-namespace tensorrt {
-
-class TRTResourceManager {
-  TRTResourceManager() = default;
-
- public:
-  static std::shared_ptr<TRTResourceManager> instance();
-  // returns a manager for given op, if it doesn't exists it creates one
-  std::shared_ptr<tensorflow::ResourceMgr> getManager(const string& op_name);
-
- private:
-  std::unordered_map<string, std::shared_ptr<tensorflow::ResourceMgr>>
-      managers_;
-  tensorflow::mutex map_mutex_;
-};
-
-}  // namespace tensorrt
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_
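Note: both deleted files implemented a process-wide singleton mapping op names to per-op resource managers, superseded by the per-op calibration resources seen earlier in the diff. The shape of the removed pattern, transliterated to Python (a sketch for understanding, not a suggestion to reintroduce it):

    import threading

    class ResourceManagerRegistry:
        """Toy transliteration of the deleted TRTResourceManager singleton."""
        _instance = None
        _instance_lock = threading.Lock()

        def __init__(self):
            self._lock = threading.Lock()
            self._managers = {}          # op name -> per-op resource dict

        @classmethod
        def instance(cls):
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = cls()
                return cls._instance

        def get_manager(self, op_name):
            # Lock held only for the lookup, as the removed C++ code noted.
            with self._lock:
                return self._managers.setdefault(op_name, {})

    m = ResourceManagerRegistry.instance().get_manager("TRTCalibration")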
@@ -24,7 +24,7 @@ package(
 )
 
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
-load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
 
 cc_library(
     name = "tf2xla_supported_ops_lib",
@@ -60,6 +60,14 @@ xla_proto_library(
     ],
 )
 
+xla_py_proto_library(
+    name = "tf2xla_py",
+    has_services = False,
+    api_version = 2,
+    visibility = ["//visibility:public"],
+    deps = [":tf2xla_proto"],
+)
+
 xla_proto_library(
     name = "host_compute_metadata_proto",
     srcs = ["host_compute_metadata.proto"],
@@ -283,6 +291,7 @@ tf_cc_test(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
     ],
 )
 
@@ -107,11 +107,13 @@ tf_kernel_library(
         "xla_pad_op.cc",
         "xla_reduce_op.cc",
         "xla_select_and_scatter_op.cc",
+        "xla_self_adjoint_eig_op.cc",
     ],
     hdrs = [
         "index_ops.h",
         "shape_util.h",
     ],
+    tags = ["optonly"],
     deps = [
         ":conv_op_helpers",
         ":if_op",
@@ -143,6 +145,7 @@ tf_kernel_library(
         "//tensorflow/compiler/xla/client/lib:prng",
         "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/client/lib:quantize",
+        "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
         "//tensorflow/compiler/xla/client/lib:sorting",
         "//tensorflow/core:bitwise_ops_op_lib",
         "//tensorflow/core:control_flow_ops_op_lib",
@@ -104,7 +104,7 @@ class SizeOp : public XlaOpKernel {
     for (int64 i = 0; i < rank; ++i) {
       size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
     }
-    size = xla::ConvertElementType(size, xla::S32);
+    size = xla::ConvertElementType(size, ctx->output_xla_type(0));
     ctx->SetOutput(0, size);
   }
 };
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -35,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -69,6 +71,43 @@ class TensorListLengthOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp);
 
+// Creates an empty list with size (leading_dim, *element_shape) if
+// element_shape is known at compile time. Otherwise creates one with size
+// (leading_dim, 0) which gets initialized later in `GetInitializedList`.
+Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index,
+                       int64 leading_dim, DataType dtype, xla::XlaOp* list) {
+  TensorShape list_shape;
+  list_shape.AddDim(leading_dim);
+  xla::XlaOp element_shape_handle = ctx->Input(element_shape_index);
+  TF_ASSIGN_OR_RETURN(
+      bool is_element_shape_compile_time_const,
+      element_shape_handle.builder()->IsConstant(element_shape_handle));
+  PartialTensorShape partial_element_shape;
+  if (is_element_shape_compile_time_const) {
+    TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(
+        element_shape_index, &partial_element_shape));
+  }
+  if (is_element_shape_compile_time_const &&
+      partial_element_shape.IsFullyDefined()) {
+    TensorShape element_shape;
+    partial_element_shape.AsTensorShape(&element_shape);
+    list_shape.AppendShape(element_shape);
+  } else {
+    // If element_shape is not a compile time constant or if it is not fully
+    // defined we will have to wait for the first write call to fully allocate
+    // the array.
+    // TODO(srbs): We are using element_shape of [0] as a proxy to denote an
+    // uninitialized list. A better implementation may be to represent the
+    // list as a 3-tuple containining an explicit "initialized" flag. However,
+    // we would still need to create a dummy tensor for the first tuple
+    // element.
+    list_shape.AddDim(0);
+  }
+  *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
+                         list_shape.dim_sizes());
+  return Status::OK();
+}
+
 class TensorListReserveOp : public XlaOpKernel {
  public:
  explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
TensorShape element_shape;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
|
|
||||||
int64 num_elements;
|
int64 num_elements;
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
||||||
|
|
||||||
TensorShape tensor_shape;
|
xla::XlaOp list;
|
||||||
tensor_shape.AddDim(num_elements);
|
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list));
|
||||||
tensor_shape.AppendShape(element_shape);
|
|
||||||
|
|
||||||
xla::XlaBuilder* b = ctx->builder();
|
xla::XlaBuilder* b = ctx->builder();
|
||||||
ctx->SetTensorListOutput(
|
ctx->SetTensorListOutput(
|
||||||
0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
|
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, num_elements)}));
|
||||||
tensor_shape.dim_sizes()),
|
|
||||||
xla::ConstantR0<int32>(b, num_elements)}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -110,8 +144,6 @@ class EmptyTensorListOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    TensorShape element_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
     int64 max_num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
     OP_REQUIRES(
@ -119,15 +151,13 @@ class EmptyTensorListOp : public XlaOpKernel {
|
|||||||
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
||||||
"size. Set the max number of elements."));
|
"size. Set the max number of elements."));
|
||||||
|
|
||||||
TensorShape tensor_shape;
|
xla::XlaOp list;
|
||||||
tensor_shape.AddDim(max_num_elements);
|
OP_REQUIRES_OK(ctx,
|
||||||
tensor_shape.AppendShape(element_shape);
|
CreateZerosList(ctx, 0, max_num_elements, dtype_, &list));
|
||||||
|
|
||||||
xla::XlaBuilder* b = ctx->builder();
|
xla::XlaBuilder* b = ctx->builder();
|
||||||
ctx->SetTensorListOutput(
|
ctx->SetTensorListOutput(
|
||||||
0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
|
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, 0)}));
|
||||||
tensor_shape.dim_sizes()),
|
|
||||||
xla::ConstantR0<int32>(b, 0)}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -274,6 +304,36 @@ REGISTER_XLA_OP(
     Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
     TensorListFromTensorOp);
 
+// Returns, in `list`, the 0'th element of `tuple` if the list it holds has
+// already been initialized; otherwise creates the list lazily. This allows
+// lazy initialization of the list on the first call to SetItem or PushBack.
+Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple,
+                          const TensorShape& element_shape, DataType dtype,
+                          xla::XlaOp* list) {
+  *list = xla::GetTupleElement(tuple, 0);
+  TensorShape list_shape;
+  TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape));
+  int64 leading_dim = list_shape.dim_size(0);
+  TensorShape list_element_shape = list_shape;
+  list_element_shape.RemoveDim(0);
+  // This checks for the lazy initialization contract set by CreateZerosList.
+  // In TensorListReserve, if the element_shape is not known at compile time,
+  // it creates a list with shape [leading_dim, 0].
+  if (element_shape != list_element_shape) {
+    if (list_element_shape.num_elements() != 0) {
+      return errors::InvalidArgument(
+          "Invalid shape of value in TensorListSetItem. Expected: ",
+          list_element_shape.DebugString(),
+          " Actual: ", element_shape.DebugString());
+    }
+    list_shape = element_shape;
+    list_shape.InsertDim(0, leading_dim);
+    *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
+                           list_shape.dim_sizes());
+  }
+  return Status::OK();
+}
+
 class TensorListSetItemOp : public XlaOpKernel {
  public:
  explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
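GetInitializedList's shape check follows the same model: an element shape with zero elements marks the uninitialized proxy and triggers the real allocation, while any other mismatch is an error. A hedged NumPy sketch mirroring that logic (hypothetical helper, not the kernel itself):

import numpy as np

def get_initialized_list(list_buf, element_shape, dtype=np.float32):
    leading_dim = list_buf.shape[0]
    list_element_shape = list_buf.shape[1:]
    if tuple(element_shape) == tuple(list_element_shape):
        return list_buf  # Already initialized.
    if int(np.prod(list_element_shape)) != 0:
        raise ValueError(
            "Invalid shape of value in TensorListSetItem. Expected: %s "
            "Actual: %s" % (list_element_shape, tuple(element_shape)))
    # Lazy initialization on the first SetItem/PushBack.
    return np.zeros([leading_dim] + list(element_shape), dtype=dtype)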
@@ -285,7 +345,9 @@ class TensorListSetItemOp : public XlaOpKernel {
     xla::XlaOp tl = ctx->Input(0);
     TensorShape elem_shape = ctx->InputShape(2);
 
-    xla::XlaOp ta = xla::GetTupleElement(tl, 0);
+    xla::XlaOp list;
+    OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list));
 
     xla::XlaOp index = ctx->Input(1);
     xla::XlaOp value = ctx->Input(2);
 
@@ -299,8 +361,8 @@ class TensorListSetItemOp : public XlaOpKernel {
     auto update = xla::Reshape(value, slice_shape.dim_sizes());
 
     ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
-                          index + xla::ConstantR0<int32>(b, 1)}));
+        0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
+                          xla::GetTupleElement(tl, 1)}));
   }
 
  private:
@@ -319,11 +381,14 @@ class TensorListPushBackOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaBuilder* b = ctx->builder();
-    xla::XlaOp tl = ctx->Input(0);
+    xla::XlaOp list_tuple = ctx->Input(0);
     TensorShape elem_shape = ctx->InputShape(1);
 
-    xla::XlaOp ta = xla::GetTupleElement(tl, 0);
-    xla::XlaOp index = xla::GetTupleElement(tl, 1);
+    xla::XlaOp list;
+    OP_REQUIRES_OK(
+        ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list));
+
+    xla::XlaOp index = xla::GetTupleElement(list_tuple, 1);
     xla::XlaOp value = ctx->Input(1);
 
     // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
@@ -336,7 +401,7 @@ class TensorListPushBackOp : public XlaOpKernel {
     auto update = xla::Reshape(value, slice_shape.dim_sizes());
 
     ctx->SetTensorListOutput(
-        0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+        0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
                           index + xla::ConstantR0<int32>(b, 1)}));
   }
 
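With the buffer guaranteed initialized, PushBack reduces to a dynamic-update-slice at the current counter plus a counter increment. The same tuple update in NumPy (hypothetical helper, illustrative only):

import numpy as np

def push_back(list_buf, index, value):
    # DynamicUpdateSlice at start_indices [index, 0, ..., 0], then bump
    # the element counter carried in the tuple's second slot.
    out = list_buf.copy()
    out[index] = value
    return out, index + 1

buf = np.zeros([4, 2])
buf, n = push_back(buf, 0, np.array([1.0, 2.0]))  # n == 1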
@@ -0,0 +1,66 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
+#include "tensorflow/core/lib/core/bits.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSelfAdjointEigOp : public XlaOpKernel {
+ public:
+  explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
+  }
+  void Compile(XlaOpKernelContext* ctx) override {
+    auto result =
+        xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_);
+    ctx->SetOutput(0, result.w);
+    ctx->SetOutput(1, result.v);
+  }
+
+ private:
+  bool lower_;
+  int32 max_iter_;
+  float epsilon_;
+};
+
+class SelfAdjointEigV2Op : public XlaOpKernel {
+ public:
+  explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape input_shape = ctx->InputShape("input");
+    int n = input_shape.dim_size(input_shape.dims() - 1);
+    // This is based on the heuristic that approximately log(n) sweep updates
+    // are needed. Note: the heuristic provides no theoretical guarantee;
+    // max_iter=100 and epsilon should be used to determine the exit condition.
+    int max_iter = 2 * tensorflow::Log2Ceiling(n);
+    auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6);
+    ctx->SetOutput(0, result.w);
+    ctx->SetOutput(1, result.v);
+  }
+};
+
+REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes),
+                XlaSelfAdjointEigOp);
+REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes),
+                SelfAdjointEigV2Op);
+
+}  // namespace
+}  // namespace tensorflow
@@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor
 rhs_output: the broadcasted RHS tensor
 )doc");
 
+REGISTER_OP("XlaSelfAdjointEig")
+    .Input("a: T")
+    .Attr("lower: bool")
+    .Attr("max_iter: int")
+    .Attr("epsilon: float")
+    .Output("w: T")
+    .Output("v: T")
+    .SetShapeFn(shape_inference::UnknownShape)
+    .Attr("T: numbertype")
+    .Doc(R"doc(
+Computes the eigen decomposition of a batch of self-adjoint matrices
+(Note: Only real inputs are supported).
+
+Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
+tensor such that tensor[...,:,:] * v[..., :,i] = w[..., i] * v[...,:,i], for
+i=0...N-1.
+
+a: the input tensor.
+
+lower: a boolean specifying whether the calculation is done with the lower
+  triangular part or the upper triangular part.
+
+max_iter: maximum number of sweep updates; a sweep visits the whole lower or
+  upper triangular part, depending on `lower`. Heuristically, it has been
+  argued that approximately log(N) sweeps are needed in practice (Ref: Golub &
+  van Loan, "Matrix Computations").
+
+epsilon: the tolerance ratio.
+
+w: The eigenvalues in ascending order, each repeated according to its
+  multiplicity.
+v: The column v[..., :, i] is the normalized eigenvector corresponding to the
+  eigenvalue w[..., i].
+)doc");
+
 REGISTER_OP("XlaConv")
     .Input("lhs: T")
     .Input("rhs: T")
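Restating the doc string's relation in standard notation (no new semantics added): for each innermost $N \times N$ matrix $A$,

\[
A v_i = w_i v_i \quad (i = 0, \dots, N-1), \qquad
A = V \operatorname{diag}(w)\, V^{\mathsf T}, \qquad V^{\mathsf T} V = I,
\]

with the eigenvalues $w$ in ascending order and the columns of $V$ the corresponding normalized eigenvectors.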
@@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
       name=name)
 
 
+def self_adjoint_eig(a, lower, max_iter, epsilon):
+  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
+
+
 dynamic_slice = gen_xla_ops.xla_dynamic_slice
 dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
 
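A usage sketch for the new wrapper; the import path matches this file's package, but the graph/session scaffolding around it is assumed, and the two outputs are unpacked per the op's (w, v) output order:

# Sketch only; assumes an XLA-compiled, TF1-style graph context.
import numpy as np
from tensorflow.compiler.tf2xla.python import xla

a = np.array([[2.0, 1.0], [1.0, 2.0]], dtype=np.float32)
w, v = xla.self_adjoint_eig(a, lower=True, max_iter=100, epsilon=1e-6)
# Expected (up to eigenvector sign): w ~ [1, 3],
# v columns ~ [1, -1]/sqrt(2) and [1, 1]/sqrt(2).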
@@ -185,9 +185,10 @@ Status BuildComputation(
   std::vector<xla::XlaOp> elems;
   elems.reserve(retvals.size());
 
-  // Keeps track of which retvals have layout to update. The first element is
-  // the output index, second element is the new layout.
-  std::vector<std::pair<int64, xla::Layout>> retval_to_update_layout;
+  // Keeps track of the layout of each retval. If a retval is not in this list,
+  // a descending layout is used. The first element is the output index, second
+  // element is the new layout.
+  std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
   for (int i = 0; i < retvals.size(); ++i) {
     XlaCompiler::OutputDescription& output = (*outputs)[i];
     const XlaExpression& retval = retvals[i];
@@ -216,7 +217,7 @@ Status BuildComputation(
         TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
                                                   output.shape, output.type));
         value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
-        retval_to_update_layout.emplace_back(elems.size(), shape.layout());
+        retval_index_and_layout.emplace_back(elems.size(), shape.layout());
       } else if (it != retval_cores.end()) {
         // Apply the sharding to the output, if there is a core assignment.
         value = identity_op(value);
@@ -289,6 +290,11 @@ Status BuildComputation(
         // Ensures the correct sharding is applied to the output.
         handle = identity_op(handle);
 
+        // Set layout of the retval to device representation layout.
+        if (resource->representation_shape().has_value()) {
+          retval_index_and_layout.emplace_back(
+              elems.size(), resource->representation_shape()->layout());
+        }
         elems.push_back(handle);
       }
     }
@@ -318,15 +324,15 @@ Status BuildComputation(
                       computation->GetProgramShape());
   *output_shape = program_shape.result();
   // Update the output layout to the layout of retval.
-  for (auto& update : retval_to_update_layout) {
+  for (auto& index_and_layout : retval_index_and_layout) {
     if (!always_return_tuple && elems.size() == 1) {
-      *output_shape->mutable_layout() = update.second;
+      *output_shape->mutable_layout() = index_and_layout.second;
       continue;
     }
 
-    xla::Shape* output_sub_shape =
-        xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first});
-    *output_sub_shape->mutable_layout() = update.second;
+    xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
+        output_shape, {index_and_layout.first});
+    *output_sub_shape->mutable_layout() = index_and_layout.second;
   }
   return Status::OK();
 }
@@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
 }
 
+// Tests that the compiler can correctly propagate the layout assigned by
+// shape_representation_fn_ to return types.
+TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+  // Adds an identity op around the resource to make sure identity ops
+  // propagate resources correctly.
+  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+  auto write = ops::AssignAddVariableOp(scope, identity, a);
+  auto read = ops::ReadVariableOp(
+      scope.WithControlDependencies(std::vector<Operation>{write}), var,
+      DT_INT32);
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2, 3});
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2, 3});
+
+  auto options = DefaultOptions();
+  options.shape_representation_fn =
+      [](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
+    xla::Shape xla_shape;
+    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
+    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
+    return xla_shape;
+  };
+  // Compiles the graph.
+  XlaCompiler compiler(options);
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
+  xla::Shape transposed =
+      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
+  // Check that the return shapes are correctly transposed.
+  EXPECT_EQ(result.xla_output_shape,
+            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
+}
+
+// The layout of a resource variable shouldn't change after a transpose.
+TEST_F(XlaCompilerTest, TransposeVariables) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+  // Adds an identity op around the resource to make sure identity ops
+  // propagate resources correctly.
+  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+  auto write = ops::AssignAddVariableOp(scope, identity, a);
+  auto read = ops::ReadVariableOp(
+      scope.WithControlDependencies(std::vector<Operation>{write}), var,
+      DT_INT32);
+  auto transposed_read = ops::Transpose(scope, read, {1, 0});
+  auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
+  auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2, 3});
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2, 3});
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
+                                     std::move(graph), args, &result));
+  xla::Shape transposed =
+      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
+  // Check that the return shapes are correctly transposed.
+  EXPECT_EQ(result.xla_output_shape,
+            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
+}
+
 // Tests that the compiler doesn't reorder the parameters.
 TEST_F(XlaCompilerTest, MixedOrderArguments) {
   for (bool swap_order : {false, true}) {
@@ -319,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
   return Status::OK();
 }
 
+Status XlaOpKernelContext::ConstantInputAsPartialShape(
+    int index, PartialTensorShape* shape) {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  // If `literal` is a scalar, its value must be -1.
+  if (literal.shape().rank() == 0) {
+    int64 shape_val;
+    TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
+    if (shape_val != -1) {
+      return errors::InvalidArgument(
+          "Cannot convert value to PartialTensorShape: ", shape_val);
+    }
+    *shape = PartialTensorShape();  // Shape with unknown rank.
+    return Status::OK();
+  }
+  std::vector<int64> dims;
+  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
+  *shape = PartialTensorShape(dims);
+  return Status::OK();
+}
+
 Status XlaOpKernelContext::InputList(absl::string_view name,
                                      std::vector<xla::XlaOp>* handles,
                                      std::vector<TensorShape>* shapes) {
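The scalar-versus-vector convention here can be summarized in a few lines. A hedged Python model of the accepted inputs (hypothetical function, not the C++ API):

def constant_input_as_partial_shape(literal):
    # Scalar case: only -1 is legal, and it means "unknown rank".
    if not isinstance(literal, (list, tuple)):
        if literal != -1:
            raise ValueError(
                "Cannot convert value to PartialTensorShape: %s" % literal)
        return None  # unknown rank
    # 1-D case: a vector of dims; -1 entries mark unknown dimensions.
    return list(literal)  # e.g. [2, -1] -> rank 2, second dim unknown

assert constant_input_as_partial_shape(-1) is None
assert constant_input_as_partial_shape([2, -1]) == [2, -1]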
@@ -447,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index,
   }
 }
 
+xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
+  xla::PrimitiveType type;
+  Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
+  if (!status.ok()) {
+    SetStatus(status);
+    return xla::PRIMITIVE_TYPE_INVALID;
+  }
+  return type;
+}
+
 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
   SetOutputExpression(
       index,
@@ -503,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
     handle = xla::Reshape(handle,
                           xla::AsInt64Slice(representation_shape.dimensions()));
   }
+  variable->SetRepresentationShape(representation_shape);
   return variable->SetValue(handle);
 }
 
@@ -138,6 +138,10 @@ class XlaOpKernelContext {
   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
   Status ConstantInputAsShape(int index, TensorShape* shape);
 
+  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1,
+  // into a PartialTensorShape.
+  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);
+
   // Returns the named list-valued immutable input in "list", as
   // defined in the OpDef. If the named output is not list-valued,
   // returns a one-element list.
@@ -155,6 +159,11 @@ class XlaOpKernelContext {
     return context_->expected_output_dtype(index);
   }
 
+  // Returns the type of output `index` as an xla::PrimitiveType. If the type
+  // is not representable as an XLA type, sets an error status and returns
+  // xla::PRIMITIVE_TYPE_INVALID.
+  xla::PrimitiveType output_xla_type(int index);
+
   // Sets output `index` to the XlaOp `handle`.
   // All outputs should be set using SetOutput and SetConstantOutput, not
   // via the underlying OpKernelContext.
@@ -86,6 +86,12 @@ class XlaResource {
   // variables have new values that need to be written back.
   const xla::XlaOp& initial_value() const { return initial_value_; }
 
+  // An xla shape that indicates how this resource variable is represented on
+  // device.
+  const absl::optional<xla::Shape>& representation_shape() const {
+    return representation_shape_;
+  }
+
   // A variable is initialized if it has a value.
   bool initialized() const { return value_.valid(); }
 
@@ -100,6 +106,11 @@ class XlaResource {
   // Sets the current value of the resource to an all-zero value.
   Status SetZeroValue(xla::XlaBuilder* builder);
 
+  // Sets the representational shape of the resource on device.
+  void SetRepresentationShape(const xla::Shape& shape) {
+    representation_shape_ = absl::make_optional(shape);
+  }
+
   // Looks up the gradient for `source`, or creates it if it does not already
   // exist. The call target must be an initialized TensorArray resource. A
   // TensorArray can have multiple named gradients; see the operator
@@ -160,6 +171,10 @@ class XlaResource {
   xla::XlaOp value_;
   xla::XlaOp initial_value_;
 
+  // An xla shape that indicates how this resource variable is represented on
+  // device.
+  absl::optional<xla::Shape> representation_shape_;
+
   int64 max_array_size_ = -1;
   bool tensor_array_multiple_writes_aggregate_ = false;
 
@@ -452,11 +452,12 @@ cc_library(
 )
 
 cc_library(
-    name = "self_adjoint_eigen",
-    srcs = ["self_adjoint_eigen.cc"],
-    hdrs = ["self_adjoint_eigen.h"],
+    name = "self_adjoint_eig",
+    srcs = ["self_adjoint_eig.cc"],
+    hdrs = ["self_adjoint_eig.h"],
     deps = [
         ":arithmetic",
+        ":comparators",
         ":constants",
         ":loops",
         ":math",
@@ -473,9 +474,12 @@ cc_library(
 )
 
 xla_test(
-    name = "self_adjoint_eigen_test",
-    size = "medium",
-    srcs = ["self_adjoint_eigen_test.cc"],
+    name = "self_adjoint_eig_test",
+    srcs = ["self_adjoint_eig_test.cc"],
+    blacklisted_backends = [
+        "cpu",
+        "gpu",
+    ],
     real_hardware_only = True,
     shard_count = 10,
     tags = ["optonly"],
@@ -483,7 +487,7 @@ xla_test(
         ":arithmetic",
         ":constants",
         ":matrix",
-        ":self_adjoint_eigen",
+        ":self_adjoint_eig",
        "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:literal",
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h"
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
 
 #include <memory>
 #include <vector>
 
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/comparators.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
@@ -42,7 +43,6 @@ namespace {
 struct SymmetricSchurDecomposition {
   XlaOp c;  // cosine.
   XlaOp s;  // sine.
-  XlaOp reduction;  // Reduction in the off diagonal after applying G.
 };
 
 // JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix
@@ -51,7 +51,11 @@ struct SymmetricSchurDecomposition {
 struct JacobiUpdate {
   XlaOp v;
   XlaOp w;
+};
+
+struct FrobeniusNorms {
   XlaOp off_diagonal_norm;
+  XlaOp total_norm;
 };
 
 // Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
@@ -79,10 +83,6 @@ StatusOr<SymmetricSchurDecomposition> SymmetricShurDecomposition2x2(XlaOp a,
   XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
 
-  PrimitiveType type = a_shape.element_type();
-
-  const int64 num_dims = a_shape.rank();
-
   auto zero = ScalarLike(a, 0.0);
   auto one = ScalarLike(a, 1.0);
   auto two = ScalarLike(a, 2.0);
@@ -110,9 +110,7 @@ StatusOr<SymmetricSchurDecomposition> SymmetricShurDecomposition2x2(XlaOp a,
 
   schur.c = c * rnorm;
   schur.s = s * rnorm;
-  schur.reduction =
-      Reduce(two * Square(pqs), zero, CreateScalarAddComputation(type, builder),
-             {num_dims - 2, num_dims - 1});
   return schur;
 }
 
@@ -196,12 +194,32 @@ StatusOr<JacobiUpdate> Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q,
   jacobi_update.v =
       DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q});
 
-  jacobi_update.off_diagonal_norm = Sqrt(
-      Max(Square(jacobi_update.off_diagonal_norm) - schur.reduction, pq_zero));
-
   return jacobi_update;
 }
 
+StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
+  XlaBuilder* builder = w.builder();
+  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
+  const int64 num_dims = shape.rank();
+  auto frobenius_norm =
+      Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
+                  CreateScalarAddComputation(shape.element_type(), builder),
+                  {num_dims - 2, num_dims - 1}));
+  auto diag = GetMatrixDiagonal(w);
+  auto diag_square =
+      Reduce(Square(diag), ScalarLike(w, 0.0),
+             CreateScalarAddComputation(shape.element_type(), builder),
+             {num_dims - 2});
+
+  FrobeniusNorms frobenius_norms;
+
+  frobenius_norms.off_diagonal_norm =
+      Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
+  frobenius_norms.total_norm = frobenius_norm;
+
+  return frobenius_norms;
+}
+
 StatusOr<std::vector<XlaOp>> WhileLoopFn(
     absl::Span<const XlaOp> initial_values,  //
     int matrix_dimension,                    //
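ComputeFrobeniusNorms relies on the identity that the off-diagonal Frobenius norm is recoverable from the total norm and the diagonal norm. A quick NumPy check of that identity (verification sketch, not the XLA code):

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 4))
# Direct off-diagonal norm.
off_direct = np.linalg.norm(w - np.diag(np.diag(w)))
# Via the identity used above: sqrt(max(||W||_F^2 - ||diag(W)||^2, 0)).
off_via_norms = np.sqrt(
    max(np.linalg.norm(w) ** 2 - np.linalg.norm(np.diag(w)) ** 2, 0.0))
assert np.isclose(off_direct, off_via_norms)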
@@ -212,62 +230,108 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
     auto k = values[0];
-    auto off_diagonal_norm = values[5];
-    // tol = frobenius_norm * epsilon.
-    auto tol = values[6] * values[7];
 
     auto max_sweeps = ScalarLike(k, max_sweep_updates);
 
     auto sweep_update_cond = Gt(max_sweeps, k);
 
-    auto tol_cond = ReduceAll(Lt(tol, off_diagonal_norm),
+    auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie();
+    auto tol = norms.total_norm * values[3];
+    auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
                               xla::ConstantR0<bool>(cond_builder, false),
                               CreateScalarOrComputation(PRED, cond_builder));
-    return And(tol_cond, sweep_update_cond);
+
+    return And(sweep_update_cond, tol_cond);
   };
 
   auto while_body_fn =
       [&](absl::Span<const XlaOp> values,
           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
-    auto zero = Zero(body_builder, index_type);
-    auto one = One(body_builder, index_type);
-    auto end_index = ScalarLike(one, matrix_dimension);
+    auto while_cond_fn_inner =
+        [&](absl::Span<const XlaOp> values_inner,
+            XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
+      auto p = values_inner[0];
+      return Lt(p, ScalarLike(p, matrix_dimension - 1));
+    };
+
+    auto while_body_fn_inner =
+        [&](absl::Span<const XlaOp> values_inner,
+            XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
+      auto while_cond_fn_innermost =
+          [&](absl::Span<const XlaOp> values_innermost,
+              XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
+        auto q = values_innermost[1];
+        return Lt(q, ScalarLike(q, matrix_dimension));
+      };
+      auto while_body_fn_innermost =
+          [&](absl::Span<const XlaOp> values_innermost,
+              XlaBuilder* innermost_body_builder)
+          -> StatusOr<std::vector<XlaOp>> {
+        auto p = values_innermost[0];
+        auto q = values_innermost[1];
+
+        JacobiUpdate jacobi_update;
+        jacobi_update.v = values_innermost[2];
+        jacobi_update.w = values_innermost[3];
+
+        auto tol = values_innermost[4];
+
+        TF_ASSIGN_OR_RETURN(jacobi_update,
+                            Update(jacobi_update, p, q, tol, matrix_dimension));
+
+        std::vector<XlaOp> updated_values_innermost;
+        updated_values_innermost.reserve(values_innermost.size());
+
+        updated_values_innermost.push_back(p);
+        updated_values_innermost.push_back(q + ScalarLike(q, 1));
+        updated_values_innermost.push_back(jacobi_update.v);
+        updated_values_innermost.push_back(jacobi_update.w);
+        updated_values_innermost.push_back(tol);
+
+        return updated_values_innermost;
+      };
+
+      std::vector<XlaOp> values_innermost(5);
+      auto p = values_inner[0];
+      auto q = p + ScalarLike(p, 1);
+      values_innermost[0] = p;                // index p.
+      values_innermost[1] = q;                // index q.
+      values_innermost[2] = values_inner[1];  // v.
+      values_innermost[3] = values_inner[2];  // w.
+      values_innermost[4] = values_inner[3];  // tol.
+      TF_ASSIGN_OR_RETURN(
+          values_innermost,
+          WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
+                          values_innermost, absl::StrCat(name, "-Innermost"),
+                          inner_body_builder));
+
+      std::vector<XlaOp> updated_values_inner;
+      updated_values_inner.reserve(values_inner.size());
+
+      updated_values_inner.push_back(p + ScalarLike(p, 1));
+      updated_values_inner.push_back(values_innermost[2]);
+      updated_values_inner.push_back(values_innermost[3]);
+      updated_values_inner.push_back(values_innermost[4]);
+      return updated_values_inner;
+    };
     // Indexes.
     XlaOp k = values[0];
-    XlaOp p = values[1];
-    XlaOp q = values[2];
-
-    JacobiUpdate jacobi_update;
-    jacobi_update.v = values[3];
-    jacobi_update.w = values[4];
-    jacobi_update.off_diagonal_norm = values[5];
-
-    XlaOp frobenius_norm = values[6];
-    XlaOp tol = values[7];
-
-    TF_ASSIGN_OR_RETURN(jacobi_update,
-                        Update(jacobi_update, p, q, tol, matrix_dimension));
+
+    std::vector<XlaOp> values_inner(4);
+    values_inner[0] = ScalarLike(k, 0);  // index p.
+    values_inner[1] = values[1];         // v.
+    values_inner[2] = values[2];         // w.
+    values_inner[3] = values[3];         // tol.
+    TF_ASSIGN_OR_RETURN(
+        values_inner,
+        WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
+                        absl::StrCat(name, "-Inner"), body_builder));
 
     std::vector<XlaOp> updated_values;
-    updated_values.reserve(values.size());
-
-    q = q + one;
-    p = Select(Eq(q, end_index), p + one, p);
-    k = Select(Eq(p, end_index - one), k + one, k);
-    p = Select(Eq(p, end_index - one), zero, p);
-    q = Select(Eq(q, end_index), p + one, q);
-
-    updated_values.push_back(k);
-    updated_values.push_back(p);
-    updated_values.push_back(q);
-
-    updated_values.push_back(jacobi_update.v);
-    updated_values.push_back(jacobi_update.w);
-    updated_values.push_back(jacobi_update.off_diagonal_norm);
-
-    updated_values.push_back(frobenius_norm);
-    updated_values.push_back(tol);
+    updated_values.reserve(values_inner.size());
+
+    updated_values.push_back(k + ScalarLike(k, 1));
+    updated_values.push_back(values_inner[1]);
+    updated_values.push_back(values_inner[2]);
+    updated_values.push_back(values_inner[3]);
 
     return updated_values;
   };
@@ -278,6 +342,27 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
   return values;
 }
 
+StatusOr<SelfAdjointEigResult> SortByEigenvalues(SelfAdjointEigResult result) {
+  XlaBuilder* builder = result.v.builder();
+  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v));
+  const int64 num_dims = shape.rank();
+  auto dimensions = shape.dimensions();
+
+  std::vector<int64> broadcast_dims(num_dims - 1);
+  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
+  broadcast_dims[num_dims - 2] = num_dims - 1;
+  result.w = BroadcastInDim(result.w, dimensions, broadcast_dims);
+
+  XlaOp sort_result =
+      Sort({result.w, result.v},
+           CreateScalarLtComputation(
+               {shape.element_type(), shape.element_type()}, builder),
+           num_dims - 1);
+  result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
+  result.v = GetTupleElement(sort_result, 1);
+  return result;
+}
+
 }  // namespace
 
 // This is the cyclic Jacobi iteration. Please note that the eigenvalues are
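SortByEigenvalues reorders eigenvalue/eigenvector pairs with a key-value Sort over the last dimension; in NumPy, the equivalent reordering is a single argsort (illustrative equivalent, not the XLA lowering):

import numpy as np

def sort_by_eigenvalues(w, v):
    # w: (..., n) eigenvalues; v: (..., n, n) with eigenvectors as columns.
    idx = np.argsort(w, axis=-1)
    w_sorted = np.take_along_axis(w, idx, axis=-1)
    v_sorted = np.take_along_axis(v, idx[..., None, :], axis=-1)
    return w_sorted, v_sorted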
@@ -286,31 +371,35 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
 // def jacobi(A):
 //   n, _ = A.shape
 //   V = np.eye(n)
-//   nfrob = np.sum(A ** 2)
-//   ndiag = np.sum(np.diag(A) ** 2)
-//   off = nfrob - ndiag
-//   while off > 1e-6 * nfrob:
+//   frobenius_norm = np.linalg.norm(A)
+//   diag_norm = np.linalg.norm(np.diag(A))
+//   off_diag_norm = np.sqrt(
+//       frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
+//   while off_diag_norm > 1e-6 * frobenius_norm:
 //     for p in range(n - 1):
 //       for q in range(p + 1, n):
-//         if off > 1e-6 * nfrob:
-//           c, s = sym_schur2x2(A, p, q)
-//           off = off - 2 * A[p, q] ** 2
-//           A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]),
-//                                    A[[p, q], :])
-//           A[:, [p, q]] = np.matmul(A[:, [p, q]],
-//                                    np.array([[c, s], [-s, c]]))
-//           V[:, [p, q]] = np.matmul(V[:, [p, q]],
-//                                    np.array([[c, s], [-s, c]]))
+//         c, s = sym_schur2x2(A, p, q)
+//         A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]),
+//                                  A[[p, q], :])
+//         A[:, [p, q]] = np.matmul(A[:, [p, q]],
+//                                  np.array([[c, s], [-s, c]]))
+//         V[:, [p, q]] = np.matmul(V[:, [p, q]],
+//                                  np.array([[c, s], [-s, c]]))
+//     frobenius_norm = np.linalg.norm(A)
+//     diag_norm = np.linalg.norm(np.diag(A))
+//     off_diag_norm = np.sqrt(
+//         frobenius_norm - diag_norm) * np.sqrt(
+//             frobenius_norm + diag_norm)
 //
 // return A, V
 //
 // TODO(kuny): Implement parallel order Jacobi.
 //
-SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter,
+SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
                                     float epsilon) {
   XlaBuilder* builder = a.builder();
   auto return_error = [&](const Status& status) {
-    SelfAdjointEigenResult result;
+    SelfAdjointEigResult result;
     result.v = builder->ReportError(status);
     result.w = builder->ReportError(status);
     return result;
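The pseudocode above is close to runnable; here is a NumPy rendering for reference, with sym_schur2x2 filled in following Golub & Van Loan's sym.schur2 (a model of the algorithm, not of the XLA lowering):

import numpy as np

def sym_schur2x2(A, p, q):
    # Returns (c, s) such that the Jacobi rotation zeroes A[p, q].
    if abs(A[p, q]) < 1e-12:
        return 1.0, 0.0
    tau = (A[q, q] - A[p, p]) / (2.0 * A[p, q])
    if tau >= 0:
        t = 1.0 / (tau + np.sqrt(1.0 + tau * tau))
    else:
        t = 1.0 / (tau - np.sqrt(1.0 + tau * tau))
    c = 1.0 / np.sqrt(1.0 + t * t)
    return c, t * c

def jacobi(A, epsilon=1e-6, max_sweeps=20):
    A = np.array(A, dtype=np.float64)
    n = A.shape[0]
    V = np.eye(n)
    for _ in range(max_sweeps):
        frobenius_norm = np.linalg.norm(A)
        diag_norm = np.linalg.norm(np.diag(A))
        off_diag_norm = np.sqrt(max(frobenius_norm**2 - diag_norm**2, 0.0))
        if off_diag_norm <= epsilon * frobenius_norm:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                c, s = sym_schur2x2(A, p, q)
                A[[p, q], :] = np.array([[c, -s], [s, c]]) @ A[[p, q], :]
                A[:, [p, q]] = A[:, [p, q]] @ np.array([[c, s], [-s, c]])
                V[:, [p, q]] = V[:, [p, q]] @ np.array([[c, s], [-s, c]])
    return np.diag(A).copy(), V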
@@ -348,33 +437,17 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
-  auto zero = ScalarLike(a, 0.0);
   auto tol = ScalarLike(a, epsilon);
 
   auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
   auto w_init = Triangle(a, lower);
   w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init;
 
-  auto frobenius_norm = Sqrt(Reduce(Square(w_init), zero,
-                                    CreateScalarAddComputation(type, builder),
-                                    {num_dims - 2, num_dims - 1}));
-  auto diag = GetMatrixDiagonal(w_init);
-  auto diag_square =
-      Reduce(Square(diag), zero, CreateScalarAddComputation(type, builder),
-             {num_dims - 2});
-
-  auto off_diagonal_init =
-      Sqrt(Max(Square(frobenius_norm) - diag_square, zero));
-
   auto output_with_status = WhileLoopFn(
       {
           Zero(builder, S32),  // k
-          Zero(builder, S32),  // p
-          One(builder, S32),   // q
-          v_init,              //
-          w_init,              //
-          off_diagonal_init,   //
-          frobenius_norm,      //
+          v_init,              // v
+          w_init,              // w
           tol,                 //
       },                       //
       n,                       //
@@ -388,11 +461,11 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
 
   auto output = output_with_status.ValueOrDie();
 
-  SelfAdjointEigenResult result;
-  result.v = output[3];
-  result.w = GetMatrixDiagonal(output[4]);
+  SelfAdjointEigResult result;
+  result.v = output[1];
+  result.w = GetMatrixDiagonal(output[2]);
 
-  return result;
+  return SortByEigenvalues(result).ValueOrDie();
 }
 
 }  // namespace xla
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
-#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_
 
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -23,20 +23,18 @@ namespace xla {
 
 // The eigenvalue decomposition of a symmetric matrix, the original matrix is
 // recovered by v * w * v_t.
-struct SelfAdjointEigenResult {
+struct SelfAdjointEigResult {
   // The i-th column is the normalized eigenvector corresponding to the
   // eigenvalue w[i]. Will return a matrix object if a is a matrix object.
   XlaOp v;
-  // TODO(kuny): Sort the eigenvalues.
   // The eigenvalues in ascending order, each repeated according to its
   // multiplicity.
   XlaOp w;
 };
 
-SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower = true,
-                                        int64 max_iter = 100,
-                                        float epsilon = 1e-6);
+SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true,
                                   int64 max_iter = 100, float epsilon = 1e-6);
 
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h"
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
 
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
@@ -32,7 +32,7 @@ limitations under the License.
 
 namespace xla {
 
-class SelfAdjointEigenTest : public ClientLibraryTestBase {
+class SelfAdjointEigTest : public ClientLibraryTestBase {
  protected:
  void SetUp() override {
    ClientLibraryTestBase::SetUp();
@@ -71,7 +71,7 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
   }
   void TearDown() override { ClientLibraryTestBase::TearDown(); }
 
-  Array3D<float> get_unit_matrix_3d(const Array3D<float>& matrix) {
+  Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
     Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
     for (int i = 0; i < matrix.n1(); ++i) {
       for (int j = 0; j < matrix.n2(); ++j) {
@@ -100,7 +100,7 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
     return result;
   }
 
-  XlaOp ComputeMatmulVWVt(SelfAdjointEigenResult result, XlaBuilder* builder) {
+  XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
     Shape shape = builder->GetShape(result.v).ValueOrDie();
     std::vector<int64> out_dims = shape.dimensions();
     std::vector<int64> broadcast_dims(shape.rank() - 1);
@@ -140,69 +140,69 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
   Array2D<int> wrong_type_4x4_;
 };
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_2x4x4) {
+XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   ComputeMatmulVWVt(result, &builder);
 
   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Lower_2x4x4) {
+XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR3Parameter<float>(
       ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   ComputeMatmulVWVt(result, &builder);
 
   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Upper_2x4x4) {
+XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR3Parameter<float>(
      ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a, false);
+  auto result = SelfAdjointEig(a, false);
   ComputeMatmulVWVt(result, &builder);
 
   ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_2x4x4) {
+XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
 
-  ComputeAndCompareR3<float>(&builder, get_unit_matrix_3d(batch_3d_4x4_),
+  ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(batch_3d_4x4_),
                              {a_data.get()}, ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
+XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   ComputeMatmulVWVt(result, &builder);
 
   ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()},
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) {
+XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
   XlaBuilder builder(TestName());
 
   // This is computed by numpy.linalg.eigh with float32.
@@ -211,21 +211,21 @@ XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
 
   XlaOp a;
   auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
-  Sort(result.w);
+  auto result = SelfAdjointEig(a);
+  Add(result.w, ZerosLike(result.w));
 
   ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) {
+XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
   XlaBuilder builder(TestName());
 
   float expected_vals = 1e-3;
 
   XlaOp a;
   auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2
   GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8),
                           BatchDot(TransposeInMinorDims(result.v), result.v),
@@ -235,66 +235,79 @@ XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
                              ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Wrong_Type_Int) {
+XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
   XlaBuilder builder(TestName());
 
   XlaOp a;
   auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   EXPECT_FALSE(result.v.valid());
   EXPECT_FALSE(result.w.valid());
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_8x8) {
+XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) {
   XlaBuilder builder(TestName());
   int size = 8;
   Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
   XlaOp a;
   auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
-  auto result = SelfAdjointEigen(a);
+  auto result = SelfAdjointEig(a);
   GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
 
   ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
-                             ErrorSpec(1e-2, 1e-2));
+                             ErrorSpec(1e-3, 1e-3));
 }
 
-XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_16x16) {
+XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) {
   XlaBuilder builder(TestName());
|
||||||
int size = 16;
|
int size = 16;
|
||||||
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
||||||
XlaOp a;
|
XlaOp a;
|
||||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||||
auto result = SelfAdjointEigen(a);
|
auto result = SelfAdjointEig(a);
|
||||||
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
||||||
|
|
||||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||||
ErrorSpec(1e-2, 1e-2));
|
ErrorSpec(1e-3, 1e-3));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_32x32) {
|
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) {
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
int size = 32;
|
int size = 32;
|
||||||
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
||||||
XlaOp a;
|
XlaOp a;
|
||||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||||
auto result = SelfAdjointEigen(a);
|
auto result = SelfAdjointEig(a);
|
||||||
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
||||||
|
|
||||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||||
ErrorSpec(1e-2, 1e-2));
|
ErrorSpec(1e-3, 1e-3));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_64x64) {
|
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) {
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
int size = 64;
|
int size = 256;
|
||||||
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
||||||
XlaOp a;
|
XlaOp a;
|
||||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||||
auto result = SelfAdjointEigen(a);
|
auto result = SelfAdjointEig(a);
|
||||||
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
||||||
|
|
||||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||||
ErrorSpec(1e-2, 1e-2));
|
ErrorSpec(1e-3, 1e-3));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) {
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
int size = 512;
|
||||||
|
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
|
||||||
|
XlaOp a;
|
||||||
|
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||||
|
auto result = SelfAdjointEig(a);
|
||||||
|
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
|
||||||
|
|
||||||
|
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||||
|
ErrorSpec(1e-3, 1e-3));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
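
Note on the rename above: every test now calls SelfAdjointEig instead of SelfAdjointEigen, and the helper get_unit_matrix_3d becomes GetUnitMatrix3D. A minimal sketch of driving the renamed entry point outside the test fixture, assuming only what the call sites above show (a result struct with eigenvector member v and eigenvalue member w, plus a second boolean parameter selecting the lower or upper triangle):

  // Sketch only; mirrors the test call sites above.
  XlaBuilder builder("self_adjoint_eig_example");
  XlaOp a = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 4}), "a");
  auto result = SelfAdjointEig(a);          // lower triangle, the default
  // auto upper = SelfAdjointEig(a, false); // upper-triangle variant
  // V * V^T should be close to the identity for orthogonal eigenvectors:
  BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);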
@@ -36,7 +36,8 @@ XlaOp TopK(XlaOp input, int64 k) {
   XlaOp sort_result =
       Sort({Neg(input), iota_s32},
            CreateScalarLtComputation({input_shape.element_type(), S32},
-                                     iota_s32.builder()));
+                                     iota_s32.builder()),
+           last_dim, /*is_stable=*/true);
   std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
   std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
   limit_indices[last_dim] = k;
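
The extra arguments make TopK deterministic for duplicate keys: sorting the pair (Neg(input), iota) with a stable sort keeps equal keys in ascending iota order, so ties always resolve to the smallest remaining index. A sketch, assuming TopK returns the sorted values/indices pair it builds above:

  XlaBuilder builder("topk_example");
  XlaOp input = ConstantR1<int>(&builder, {1, 1, 2, 2, 1});
  // Descending stable sort yields values {2, 2, 1, 1, 1} with indices
  // {2, 3, 0, 1, 4} on every backend, so TopK(input, 3) reports values
  // {2, 2, 1} with indices {2, 3, 0}.
  XlaOp top3 = TopK(input, 3);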
@@ -81,9 +81,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
   ComputeAndCompareR1<float>(&builder, inputs, {});
 }

-// TODO(b/122298745): Enable this test when the GPU backend supports stable
-// sorting.
-XLA_TEST_F(SortingTest, DISABLED_ON_GPU(TopKFullSortWithDuplicates)) {
+XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
   XlaBuilder builder(TestName());
   XlaOp a;
   auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a);
@@ -1663,14 +1663,16 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
     Lt(first_lhs_param, first_rhs_param);

     TF_ASSIGN_OR_RETURN(auto comparator, b->Build());
-    return Sort(operands, comparator, dimension);
+    return Sort(operands, comparator, dimension, /*is_stable=*/false);
   });
 }

 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
-                       const XlaComputation& comparator, int64 dimension) {
+                       const XlaComputation& comparator, int64 dimension,
+                       bool is_stable) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
+    instr.set_is_stable(is_stable);
     std::vector<const Shape*> operand_shape_ptrs;
     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                         GetOperandShapes(operands));
@@ -3320,8 +3322,9 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values, int64 dimension) {
 }

 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
-           int64 dimension) {
-  return operands[0].builder()->Sort(operands, comparator, dimension);
+           int64 dimension, bool is_stable) {
+  return operands[0].builder()->Sort(operands, comparator, dimension,
+                                     is_stable);
 }

 XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
@@ -505,7 +505,7 @@ class XlaBuilder {
   XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
              int64 dimension = -1);
   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
-             int64 dimension = -1);
+             int64 dimension = -1, bool is_stable = false);

   XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

@@ -923,7 +923,8 @@ class XlaBuilder {
   friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
                     int64 dimension);
   friend XlaOp Sort(absl::Span<const XlaOp> operands,
-                    const XlaComputation& comparator, int64 dimension);
+                    const XlaComputation& comparator, int64 dimension,
+                    bool is_stable);
   friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                    const XlaComputation& computation,
@@ -1695,7 +1696,8 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
            int64 dimension = -1);

 // Enqueues a sort instruction onto the computation, using 'comparator' for
-// comparisons. 'comparator' needs to define a strict weak order.
+// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
+// determines whether the stable sorting should be used.
 // If only one operand is provided:
 // * If the operand is a rank-1 tensor (an array), the result is a sorted array.
 // The resulting sorting order has the property that for all index positions
@@ -1718,7 +1720,7 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
 // correspond to the value of operand i at two index positions.
 // Default comparator computations can be found in lib/comparators.h
 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
-           int64 dimension = -1);
+           int64 dimension = -1, bool is_stable = false);

 // Enqueues a clamp instruction onto the computation.
 XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
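
A sketch of how caller code opts into the new stable variant through the free function declared above, paired with the comparator helper from lib/comparators.h that the TopK change already uses (the shapes here are illustrative):

  XlaBuilder b("stable_sort_example");
  XlaOp keys = ConstantR1<float>(&b, {3.0f, 1.0f, 3.0f, 2.0f});
  XlaOp payload = Iota(&b, S32, 4);
  // With is_stable=true, entries with equal keys keep the relative order of
  // `payload`, so the result is reproducible across backends.
  XlaOp sorted = Sort({keys, payload},
                      CreateScalarLtComputation({F32, S32}, &b),
                      /*dimension=*/0, /*is_stable=*/true);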
@@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
 }

 ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
-    DeviceAssignment* device_assignment) {
+    const DeviceAssignment* device_assignment) {
   device_assignment_ = device_assignment;
   return *this;
 }
@@ -74,7 +74,7 @@ class ExecutableRunOptions {
   ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);

   ExecutableRunOptions& set_device_assignment(
-      DeviceAssignment* device_assignment);
+      const DeviceAssignment* device_assignment);
   const DeviceAssignment* device_assignment() const;

   ExecutableRunOptions& set_rng_seed(int rng_seed);
@@ -83,7 +83,7 @@ class ExecutableRunOptions {
  private:
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
-  DeviceAssignment* device_assignment_ = nullptr;
+  const DeviceAssignment* device_assignment_ = nullptr;
   stream_executor::Stream* stream_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;
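
The const-qualification above lets callers keep the assignment itself const, which the executable-owned DeviceAssignment member later in this change relies on. A minimal sketch (the helper name is illustrative):

  void ConfigureRunOptions(const DeviceAssignment& assignment,
                           ExecutableRunOptions* options) {
    // Compiles only now that set_device_assignment accepts a const pointer.
    options->set_device_assignment(&assignment);
  }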
@@ -77,6 +77,7 @@ cc_library(
         "//tensorflow/compiler/xla/client/lib:cholesky",
         "//tensorflow/compiler/xla/client/lib:math",
         "//tensorflow/compiler/xla/client/lib:qr",
+        "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -53,74 +54,6 @@ namespace swig {
 // TODO(b/118641336): Factor out XRT parts into a small c++ library of their
 // own.

-// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
-// device handles instead of needing to set the number of replicas at XLA
-// service initialization time.
-tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED);
-int g_replica_count GUARDED_BY(g_local_client_mutex) = 1;
-LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr;
-
-string* GetPlatformNameString() {
-  static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) =
-      new string("Host");
-  return platform_name_string;
-}
-
-Status InitializeReplicaCount(int replica_count) {
-  if (replica_count < 1) {
-    return InvalidArgument("Replica count must be >= 1; got %d.",
-                           replica_count);
-  }
-  tensorflow::mutex_lock lock(g_local_client_mutex);
-  if (g_local_client != nullptr) {
-    return FailedPrecondition(
-        "Attempted to set the replica count to %d, but a local XLA service was "
-        "previously created with a replica count of %d.",
-        replica_count, g_replica_count);
-  }
-  g_replica_count = replica_count;
-  return Status::OK();
-}
-
-Status InitializePlatformName(const string& platform_name) {
-  string* g_platform_name = GetPlatformNameString();
-  tensorflow::mutex_lock lock(g_local_client_mutex);
-  if (g_local_client != nullptr) {
-    return FailedPrecondition(
-        "Attempted to set the platform name to %s, but a local XLA service was "
-        "previously created with a platform name of %s.",
-        platform_name, *g_platform_name);
-  }
-  TF_ASSIGN_OR_RETURN(se::Platform * platform,
-                      PlatformUtil::GetPlatform(platform_name));
-  if (platform->VisibleDeviceCount() <= 0) {
-    return InvalidArgument("Platform %s has no visible devices.",
-                           platform_name);
-  }
-  *g_platform_name = platform_name;
-  return Status::OK();
-}
-
-int GetReplicaCount() {
-  tensorflow::mutex_lock lock(g_local_client_mutex);
-  return g_replica_count;
-}
-
-StatusOr<LocalClient*> GetOrCreateLocalClient() {
-  string* platform_name = GetPlatformNameString();
-  tensorflow::mutex_lock lock(g_local_client_mutex);
-  if (g_local_client != nullptr) {
-    return g_local_client;
-  }
-  LocalClientOptions options;
-  options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie());
-  options.set_number_of_replicas(g_replica_count);
-  TF_ASSIGN_OR_RETURN(g_local_client,
-                      ClientLibrary::GetOrCreateLocalClient(options));
-  CHECK(g_local_client != nullptr);
-  return g_local_client;
-}
-
 Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
   const char* name = "xla._CPU_CUSTOM_CALL_TARGET";
   if (!PyCapsule_IsValid(capsule, name)) {
@@ -135,62 +68,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
   return Status::OK();
 }

-Status TransferToInfeedLocal(const Literal& literal) {
-  VLOG(1) << "Infeeding literal without replica number; shape: "
-          << literal.shape();
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0);
-}
+LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {}

-Status TransferToInfeedLocalReplica(const Literal& literal,
-                                    int replica_number) {
-  VLOG(1) << "Infeeding shape " << literal.shape()
-          << " to replica number: " << replica_number;
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  TF_ASSIGN_OR_RETURN(int device_ordinal,
-                      client->ReplicaNumberToDeviceOrdinal(replica_number));
-  return client->TransferToInfeedLocal(literal, device_ordinal);
-}
+/* static */ StatusOr<LocalClient> LocalClient::Get(
+    const string& platform_name) {
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform(platform_name));
+  if (platform->VisibleDeviceCount() <= 0) {
+    return InvalidArgument("Platform %s has no visible devices.",
+                           platform_name);
+  }
+  LocalClientOptions options;
+  options.set_platform(platform);
+  TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
+                      ClientLibrary::GetOrCreateLocalClient(options));
+  CHECK(client != nullptr);
+  return LocalClient(client);
+}

-StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
-                                                  int replica_number) {
-  VLOG(1) << "Outfeeding literal from replica number: " << replica_number
-          << " shape: " << shape;
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  TF_ASSIGN_OR_RETURN(int device_ordinal,
-                      client->ReplicaNumberToDeviceOrdinal(replica_number));
-  return client->TransferFromOutfeedLocal(shape, device_ordinal);
-}
+// Returns the number of devices known to the XLA client.
+int LocalClient::DeviceCount() const { return client_->device_count(); }

-static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
-                                             int device_ordinal,
-                                             const Literal& arg) {
-  return client->LiteralToShapedBuffer(arg, device_ordinal,
-                                       client->backend().memory_allocator());
+Status LocalClient::TransferToInfeed(const Literal& literal,
+                                     int device_ordinal) {
+  VLOG(1) << "Infeeding literal to device " << device_ordinal
+          << "; shape: " << literal.shape();
+  return client_->TransferToInfeed(literal, device_ordinal);
+}
+
+StatusOr<Literal> LocalClient::TransferFromOutfeed(const Shape& shape,
+                                                   int device_ordinal) {
+  VLOG(1) << "Outfeeding literal from device " << device_ordinal
+          << "; shape: " << shape;
+  return client_->TransferFromOutfeed(&shape, device_ordinal);
 }

 /* static */
 StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
     const Literal& argument, const absl::optional<Shape>& shape_with_layout,
-    int replica_number) {
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  TF_ASSIGN_OR_RETURN(int device_ordinal,
-                      client->ReplicaNumberToDeviceOrdinal(replica_number));
-  VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: "
-          << replica_number << "/" << device_ordinal;
+    const LocalClient& client, int device_ordinal) {
+  VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
+          << device_ordinal;
+  auto literal_to_buffer = [&](const Literal& arg) {
+    return client.client()->LiteralToShapedBuffer(
+        arg, device_ordinal, client.client()->backend().memory_allocator());
+  };

   StatusOr<ScopedShapedBuffer> buf = [&] {
     if (shape_with_layout) {
       Literal relaid = argument.Relayout(shape_with_layout.value());
-      return ToBuffer(client, device_ordinal, relaid);
+      return literal_to_buffer(relaid);
     }
-    return ToBuffer(client, device_ordinal, argument);
+    return literal_to_buffer(argument);
   }();
   TF_RETURN_IF_ERROR(buf.status());
-  return new LocalShapedBuffer(std::move(buf).ValueOrDie());
+  return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client());
 }

-LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer)
-    : shaped_buffer_(std::move(shaped_buffer)) {}
+LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
+                                     xla::LocalClient* client)
+    : shaped_buffer_(std::move(shaped_buffer)), client_(client) {}

 const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
   return &shaped_buffer_;
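
With the process-wide globals deleted above, client state now travels through an explicit LocalClient wrapper. A sketch of the replacement flow, using only the methods introduced in this hunk:

  Status InfeedToFirstDevice(const Literal& literal) {
    // Gets (or lazily creates) the client for a named platform.
    TF_ASSIGN_OR_RETURN(LocalClient client, LocalClient::Get("Host"));
    VLOG(1) << "Client sees " << client.DeviceCount() << " device(s)";
    // Callers now address devices by ordinal instead of replica number.
    return client.TransferToInfeed(literal, /*device_ordinal=*/0);
  }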
@@ -203,8 +140,7 @@ const Shape& LocalShapedBuffer::shape() const {
 }

 StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  return client->ShapedBufferToLiteral(*shaped_buffer());
+  return client_->ShapedBufferToLiteral(*shaped_buffer());
 }

 LocalShapedBufferTuple::LocalShapedBufferTuple(
@@ -235,6 +171,51 @@ StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {

 int64 LocalShapedBufferTuple::size() const { return elements_.size(); }

+StatusOr<LocalShapedBufferTuple*> LocalShapedBuffer::DestructureTuple() {
+  const Shape tuple_shape = shape();
+
+  if (!tuple_shape.IsTuple()) {
+    return InvalidArgument(
+        "Attempted to destructure a LocalShapedBuffer that did not have a "
+        "tuple shape; shape: %s",
+        ShapeUtil::HumanString(tuple_shape));
+  }
+
+  DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator();
+  ShapedBuffer tuple_buffer = Release();
+
+  // Extract some metadata we use to construct scoped buffers.
+  const se::Platform* platform = tuple_buffer.platform();
+  int device_ordinal = tuple_buffer.device_ordinal();
+
+  ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
+  std::vector<LocalShapedBuffer*> results;
+  for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
+    // Create a shaped buffer for this destructured tuple element.
+    const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
+    VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
+    ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);
+
+    ShapeUtil::ForEachSubshape(
+        subshape, [&](const Shape& s, const ShapeIndex& index) {
+          ShapeIndex original(index);
+          original.push_front(i);
+          se::DeviceMemoryBase* device_memory =
+              shape_tree.mutable_element(original);
+          shaped_buffer.set_buffer(*device_memory, index);
+          *device_memory = se::DeviceMemoryBase();
+        });
+
+    VLOG(3) << "Completed tuple element: " << i;
+    results.push_back(new LocalShapedBuffer(
+        ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
+  }
+  // Deallocate the root buffer.
+  se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
+  TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
+  return new LocalShapedBufferTuple(std::move(results));
+}
+
 XrtAllocation::XrtAllocation(int64 handle, Shape shape,
                              const string& session_target)
     : handle_(handle), shape_(shape), session_target_(session_target) {}
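
A sketch of how the new DestructureTuple is meant to be consumed. It releases the tuple's buffers into per-element LocalShapedBuffers, so the source buffer must not be reused afterwards; as elsewhere in this layer, the raw pointers are owned by the (SWIG) caller:

  StatusOr<Literal> FirstTupleElementToLiteral(LocalShapedBuffer* tuple) {
    // Splits the tuple; `tuple` gives up ownership of its device memory.
    TF_ASSIGN_OR_RETURN(LocalShapedBufferTuple * elements,
                        tuple->DestructureTuple());
    TF_ASSIGN_OR_RETURN(LocalShapedBuffer * first, elements->Release(0));
    return first->ToLiteral();  // deletion of the handles elided in this sketch
  }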
@@ -332,23 +313,32 @@ StatusOr<XrtAllocation*> XrtAllocationTuple::Release(int i) {

 int64 XrtAllocationTuple::size() const { return elements_.size(); }

-CompiledLocalComputation::CompiledLocalComputation(
-    std::unique_ptr<LocalExecutable> executable)
-    : executable_(std::move(executable)) {}
+LocalExecutable::LocalExecutable(
+    std::unique_ptr<xla::LocalExecutable> executable,
+    xla::DeviceAssignment device_assignment, xla::LocalClient* client)
+    : executable_(std::move(executable)),
+      device_assignment_(std::move(device_assignment)),
+      client_(client) {}

-StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
+std::vector<int> LocalExecutable::DeviceOrdinals() const {
+  int num_replicas = device_assignment_.replica_count();
+  std::vector<int> device_ordinals;
+  device_ordinals.reserve(num_replicas);
+  for (int i = 0; i < num_replicas; ++i) {
+    device_ordinals.push_back(device_assignment_(i, 0));
+  }
+  return device_ordinals;
+}
+
+StatusOr<LocalShapedBuffer*> LocalExecutable::Execute(
     absl::Span<LocalShapedBuffer* const> argument_handles) {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute()",
         num_replicas());
   }
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
-                      client->backend().computation_placer()->AssignDevices(
-                          1, /*computation_count=*/1));
   StatusOr<ScopedShapedBuffer> result_buffer_status;
-  const int device_ordinal = device_assignment(0, 0);
+  const int device_ordinal = device_assignment_(0, 0);
   VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
           << device_ordinal;

@@ -360,10 +350,10 @@ StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(

     ExecutableRunOptions options;
     options.set_device_ordinal(device_ordinal);
-    options.set_allocator(client->backend().memory_allocator());
+    options.set_allocator(client_->backend().memory_allocator());
     options.set_intra_op_thread_pool(
-        client->backend().eigen_intra_op_thread_pool_device());
-    options.set_device_assignment(&device_assignment);
+        client_->backend().eigen_intra_op_thread_pool_device());
+    options.set_device_assignment(&device_assignment_);

     result_buffer_status = executable_->Run(argument_buffers, options);

@@ -373,13 +363,13 @@ StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
         "%s.",
         result_buffer_status.status().ToString());
   }
-  return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie());
+  return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
+                               client_);
 }

-StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(
+StatusOr<LocalShapedBufferTuple*> LocalExecutable::ExecutePerReplica(
     absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
-  const int num_devices = client->device_count();
+  const int num_devices = client_->device_count();

   if (argument_handles.size() != num_replicas()) {
     return InvalidArgument(
@@ -394,14 +384,9 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(

   VLOG(1) << "Executing with " << num_replicas() << " replicas.";

-  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
-                      client->backend().computation_placer()->AssignDevices(
-                          num_replicas(), /*computation_count=*/1));
-
   std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
-  auto execute = [this, client, &device_assignment, &argument_handles,
-                  &results](int replica) {
-    const int device_ordinal = device_assignment(replica, 0);
+  auto execute = [this, &argument_handles, &results](int replica) {
+    const int device_ordinal = device_assignment_(replica, 0);
     VLOG(3) << "Replica " << replica
             << " mapped to device ordinal for execution: " << device_ordinal;

@@ -413,10 +398,10 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(

     ExecutableRunOptions options;
     options.set_device_ordinal(device_ordinal);
-    options.set_allocator(client->backend().memory_allocator());
+    options.set_allocator(client_->backend().memory_allocator());
     options.set_intra_op_thread_pool(
-        client->backend().eigen_intra_op_thread_pool_device());
-    options.set_device_assignment(&device_assignment);
+        client_->backend().eigen_intra_op_thread_pool_device());
+    options.set_device_assignment(&device_assignment_);
     StatusOr<ScopedShapedBuffer> result_buffer_status =
         executable_->Run(argument_buffers, options);

@@ -448,26 +433,19 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(
           replica, statusor.status().ToString());
     }
     wrapped_results[replica] =
-        new LocalShapedBuffer(std::move(statusor).ValueOrDie());
+        new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
   }

   return new LocalShapedBufferTuple(std::move(wrapped_results));
 }

-static StatusOr<Shape> GetReturnValueShape(const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
-                      computation.GetProgramShape());
-  return std::move(*program_shape.mutable_result());
-}
-
-CompiledXrtComputation::CompiledXrtComputation(
-    const ProgramShape& program_shape, int64 handle,
-    const string& session_target)
+XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle,
+                             const string& session_target)
     : program_shape_(program_shape),
       handle_(handle),
       session_target_(session_target) {}

-CompiledXrtComputation::~CompiledXrtComputation() {
+XrtExecutable::~XrtExecutable() {
   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
   auto computation_handle =
       tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
@@ -489,7 +467,7 @@ CompiledXrtComputation::~CompiledXrtComputation() {
   }
 }

-StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
+StatusOr<XrtAllocation*> XrtExecutable::Execute(
     absl::Span<XrtAllocation* const> argument_handles) {
   const int num_expected_arguments = program_shape().parameters().size();

@@ -528,36 +506,41 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
   return new XrtAllocation(output, program_shape().result(), session_target_);
 }

-const ProgramShape& CompiledXrtComputation::program_shape() const {
+const ProgramShape& XrtExecutable::program_shape() const {
   return program_shape_;
 }

-int64 CompiledXrtComputation::handle() const { return handle_; }
+int64 XrtExecutable::handle() const { return handle_; }

-LocalComputation::LocalComputation(XlaComputation computation)
+Computation::Computation(XlaComputation computation)
     : computation_(std::move(computation)) {}

-StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
+StatusOr<LocalExecutable*> Computation::Compile(
     const std::vector<Shape>& argument_shapes,
-    const ExecutableBuildOptions* build_options) {
+    const ExecutableBuildOptions* build_options, const LocalClient& client) {
   std::vector<const Shape*> argument_shape_pointers;
   argument_shape_pointers.reserve(argument_shapes.size());
   for (auto& argument_shape : argument_shapes) {
     argument_shape_pointers.push_back(&argument_shape);
   }

-  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
   ExecutableBuildOptions options;
   if (build_options != nullptr) {
     options = *build_options;
   }
   TF_ASSIGN_OR_RETURN(
       auto local_executable,
-      client->Compile(computation_, argument_shape_pointers, options));
-  return new CompiledLocalComputation(std::move(local_executable));
+      client.client()->Compile(computation_, argument_shape_pointers, options));
+  TF_ASSIGN_OR_RETURN(
+      DeviceAssignment device_assignment,
+      client.client()->backend().computation_placer()->AssignDevices(
+          options.num_replicas(), /*computation_count=*/1));
+
+  return new LocalExecutable(std::move(local_executable),
+                             std::move(device_assignment), client.client());
 }

-StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
+StatusOr<XrtExecutable*> Computation::CompileForXrt(
     const std::vector<Shape>& argument_shapes, const string& session_target) {
   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
   auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
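
Taken together, Compile now binds the executable to a client and a device assignment once, and Execute simply reuses them. A sketch of the end-to-end flow under the new names, with build options and argument shapes elided:

  StatusOr<LocalShapedBuffer*> CompileAndRun(
      Computation* computation, const LocalClient& client,
      absl::Span<LocalShapedBuffer* const> args) {
    // The device assignment is computed once here, not on every Execute call.
    TF_ASSIGN_OR_RETURN(LocalExecutable * executable,
                        computation->Compile(/*argument_shapes=*/{},
                                             /*build_options=*/nullptr, client));
    return executable->Execute(args);  // single-replica path
  }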
@@ -585,14 +568,12 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       computation().GetProgramShape());
   int64 handle = outputs[0].scalar<int64>()();
-  return new CompiledXrtComputation(program_shape, handle, session_target);
+  return new XrtExecutable(program_shape, handle, session_target);
 }

-const XlaComputation& LocalComputation::computation() const {
-  return computation_;
-}
+const XlaComputation& Computation::computation() const { return computation_; }

-string LocalComputation::GetSerializedProto() const {
+string Computation::GetSerializedProto() const {
   string result;
   if (!computation_.proto().SerializeToString(&result)) {
     LOG(ERROR) << "Failed to serialize the HloModuleProto.";
@@ -601,101 +582,103 @@ string LocalComputation::GetSerializedProto() const {
   return result;
 }

-StatusOr<Shape> LocalComputation::GetReturnValueShape() const {
-  return swig::GetReturnValueShape(computation_);
+StatusOr<ProgramShape> Computation::GetProgramShape() const {
+  return computation_.GetProgramShape();
+}
+
+StatusOr<Shape> Computation::GetReturnValueShape() const {
+  TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape());
+  return std::move(*shape.mutable_result());
 }

 LocalOp::LocalOp(const XlaOp& op) : op_(op) {}

 const XlaOp& LocalOp::op() const { return op_; }

-LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
+ComputationBuilder::ComputationBuilder(const string& computation_name)
     : builder_(computation_name) {}

-void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
+void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
   builder_.SetOpMetadata(metadata);
 }

-void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }
+void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }

-StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
+StatusOr<Computation*> ComputationBuilder::Build() {
   TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build());
-  return new LocalComputation(std::move(computation));
+  return new Computation(std::move(computation));
 }

-LocalOp LocalComputationBuilder::Parameter(int64 parameter_number,
-                                           const Shape& shape,
-                                           const string& name) {
+LocalOp ComputationBuilder::Parameter(int64 parameter_number,
+                                      const Shape& shape, const string& name) {
   return xla::Parameter(&builder_, parameter_number, shape, name);
 }

-StatusOr<LocalComputation*> LocalComputationBuilder::BuildWithRoot(
-    const LocalOp& root) {
+StatusOr<Computation*> ComputationBuilder::BuildWithRoot(const LocalOp& root) {
   TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op()));
-  return new LocalComputation(std::move(computation));
+  return new Computation(std::move(computation));
 }

-StatusOr<Shape> LocalComputationBuilder::GetShape(const LocalOp& operand) {
+StatusOr<Shape> ComputationBuilder::GetShape(const LocalOp& operand) {
   return builder_.GetShape(operand.op());
 }

-StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
+StatusOr<Shape> ComputationBuilder::GetReturnValueShape() {
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
   return program_shape.result();
 }

-LocalOp LocalComputationBuilder::Infeed(const Shape& shape) {
+LocalOp ComputationBuilder::Infeed(const Shape& shape) {
   return xla::Infeed(&builder_, shape);
 }

-void LocalComputationBuilder::Outfeed(const LocalOp& operand,
-                                      const Shape& shape,
-                                      const string& outfeed_config) {
+void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape,
                                 const string& outfeed_config) {
   xla::Outfeed(operand.op(), shape, outfeed_config);
 }

-LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
+LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) {
   return xla::ConstantLiteral(&builder_, literal);
 }

-LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
+LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
   return xla::Iota(&builder_, element_type, size);
 }

-LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape,
-                                                 int64 dimension) {
+LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape,
+                                            int64 dimension) {
   return xla::Iota(&builder_, shape, dimension);
 }

-LocalOp LocalComputationBuilder::Broadcast(
-    const LocalOp& operand, absl::Span<const int64> broadcast_sizes) {
+LocalOp ComputationBuilder::Broadcast(const LocalOp& operand,
+                                      absl::Span<const int64> broadcast_sizes) {
   return xla::Broadcast(operand.op(), broadcast_sizes);
 }

-LocalOp LocalComputationBuilder::BroadcastInDim(
+LocalOp ComputationBuilder::BroadcastInDim(
     const LocalOp& operand, absl::Span<const int64> out_dim_sizes,
     absl::Span<const int64> broadcast_dimensions) {
   return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions);
 }

-LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
-                                     const LocalOp& padding_value,
-                                     const PaddingConfig& padding_config) {
+LocalOp ComputationBuilder::Pad(const LocalOp& operand,
+                                const LocalOp& padding_value,
+                                const PaddingConfig& padding_config) {
   return xla::Pad(operand.op(), padding_value.op(), padding_config);
 }

-LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand,
-                                         absl::Span<const int64> dimensions,
-                                         absl::Span<const int64> new_sizes) {
+LocalOp ComputationBuilder::Reshape(const LocalOp& operand,
+                                    absl::Span<const int64> dimensions,
+                                    absl::Span<const int64> new_sizes) {
   return xla::Reshape(operand.op(), dimensions, new_sizes);
 }

-LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand,
-                                          absl::Span<const int64> dimensions) {
+LocalOp ComputationBuilder::Collapse(const LocalOp& operand,
+                                     absl::Span<const int64> dimensions) {
   return xla::Collapse(operand.op(), dimensions);
 }

-LocalOp LocalComputationBuilder::AllToAll(
+LocalOp ComputationBuilder::AllToAll(
     const LocalOp& operand, int64 split_dimension, int64 concat_dimension,
     int64 split_count, absl::Span<const ReplicaGroup> replica_groups) {
   std::vector<ReplicaGroup> rg(replica_groups.size());
@@ -706,39 +689,38 @@ LocalOp LocalComputationBuilder::AllToAll(
                            split_count, rg);
 }

-LocalOp LocalComputationBuilder::CrossReplicaSum(
+LocalOp ComputationBuilder::CrossReplicaSum(
     const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
   return xla::CrossReplicaSum(operand.op(), replica_groups);
 }

-LocalOp LocalComputationBuilder::Slice(const LocalOp& operand,
-                                       absl::Span<const int64> start_indices,
-                                       absl::Span<const int64> limit_indices,
-                                       absl::Span<const int64> strides) {
+LocalOp ComputationBuilder::Slice(const LocalOp& operand,
+                                  absl::Span<const int64> start_indices,
+                                  absl::Span<const int64> limit_indices,
+                                  absl::Span<const int64> strides) {
   return xla::Slice(operand.op(), start_indices, limit_indices, strides);
 }

-LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
-                                            int64 start_index,
-                                            int64 limit_index, int64 stride,
-                                            int64 dimno) {
+LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand,
+                                       int64 start_index, int64 limit_index,
+                                       int64 stride, int64 dimno) {
   return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno);
 }

-LocalOp LocalComputationBuilder::DynamicSlice(
-    const LocalOp& operand, const LocalOp& start_indices,
-    absl::Span<const int64> slice_sizes) {
+LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand,
+                                         const LocalOp& start_indices,
+                                         absl::Span<const int64> slice_sizes) {
   return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
 }

-LocalOp LocalComputationBuilder::DynamicUpdateSlice(
-    const LocalOp& operand, const LocalOp& update,
-    const LocalOp& start_indices) {
+LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand,
+                                               const LocalOp& update,
+                                               const LocalOp& start_indices) {
   return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
 }

-LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
-                                             int64 dimension) {
+LocalOp ComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
+                                        int64 dimension) {
   std::vector<XlaOp> xla_ops;
   xla_ops.reserve(operands.size());
   for (const auto& op : operands) {
@@ -747,18 +729,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
   return xla::ConcatInDim(&builder_, xla_ops, dimension);
 }

-LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
-    const LocalOp& operand, const LocalComputation& select,
+LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding(
+    const LocalOp& operand, const Computation& select,
     absl::Span<const int64> window_dimensions,
     absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
-    const LocalOp& init_value, const LocalComputation& scatter) {
+    const LocalOp& init_value, const Computation& scatter) {
   return xla::SelectAndScatterWithGeneralPadding(
       operand.op(), select.computation(), window_dimensions, window_strides,
       padding, source.op(), init_value.op(), scatter.computation());
 }

-LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
+LocalOp ComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
   std::vector<XlaOp> xla_ops;
   xla_ops.reserve(elements.size());
   for (const auto& op : elements) {
@@ -768,22 +750,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
   return xla::Tuple(&builder_, xla_ops);
 }

-LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
-                                                 int64 index) {
+LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
+                                            int64 index) {
   return xla::GetTupleElement(tuple_data.op(), index);
 }

-LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
+LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
   return xla::Dot(lhs.op(), rhs.op());
 }

-LocalOp LocalComputationBuilder::DotGeneral(
+LocalOp ComputationBuilder::DotGeneral(
     const LocalOp& lhs, const LocalOp& rhs,
     const DotDimensionNumbers& dimension_numbers) {
   return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
 }

-LocalOp LocalComputationBuilder::ConvGeneralDilated(
+LocalOp ComputationBuilder::ConvGeneralDilated(
     const LocalOp& lhs, const LocalOp& rhs,
     absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
@@ -795,18 +777,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
                                  feature_group_count);
 }

-LocalOp LocalComputationBuilder::ConvertElementType(
-    const LocalOp& operand, PrimitiveType new_element_type) {
+LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand,
+                                               PrimitiveType new_element_type) {
   return xla::ConvertElementType(operand.op(), new_element_type);
 }

-LocalOp LocalComputationBuilder::BitcastConvertType(
-    const LocalOp& operand, PrimitiveType new_element_type) {
+LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand,
+                                               PrimitiveType new_element_type) {
   return xla::BitcastConvertType(operand.op(), new_element_type);
 }

-LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
-                                      absl::Span<const LocalOp> operands) {
+LocalOp ComputationBuilder::Call(const Computation& local_computation,
+                                 absl::Span<const LocalOp> operands) {
   std::vector<XlaOp> xla_ops;
   xla_ops.reserve(operands.size());
   for (const auto& op : operands) {
@@ -815,7 +797,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
   return xla::Call(&builder_, local_computation.computation(), xla_ops);
 }

-LocalOp LocalComputationBuilder::CustomCall(
+LocalOp ComputationBuilder::CustomCall(
     const string& call_target_name, absl::Span<const LocalOp> operands,
     const Shape& shape_with_layout,
     const std::vector<Shape>& operand_shapes_with_layout,
@@ -830,19 +812,19 @@ LocalOp LocalComputationBuilder::CustomCall(
                          operand_shapes_with_layout, opaque);
 }

-LocalOp LocalComputationBuilder::Transpose(
-    const LocalOp& operand, absl::Span<const int64> permutation) {
+LocalOp ComputationBuilder::Transpose(const LocalOp& operand,
+                                      absl::Span<const int64> permutation) {
|
||||||
return xla::Transpose(operand.op(), permutation);
|
return xla::Transpose(operand.op(), permutation);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Rev(const LocalOp& operand,
|
LocalOp ComputationBuilder::Rev(const LocalOp& operand,
|
||||||
absl::Span<const int64> dimensions) {
|
absl::Span<const int64> dimensions) {
|
||||||
return xla::Rev(operand.op(), dimensions);
|
return xla::Rev(operand.op(), dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
|
LocalOp ComputationBuilder::Map(absl::Span<const LocalOp> operands,
|
||||||
const LocalComputation& local_computation,
|
const Computation& local_computation,
|
||||||
absl::Span<const int64> dimensions) {
|
absl::Span<const int64> dimensions) {
|
||||||
std::vector<XlaOp> xla_ops;
|
std::vector<XlaOp> xla_ops;
|
||||||
xla_ops.reserve(operands.size());
|
xla_ops.reserve(operands.size());
|
||||||
for (const auto& op : operands) {
|
for (const auto& op : operands) {
|
||||||
@ -853,17 +835,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
|
|||||||
dimensions);
|
dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Reduce(
|
LocalOp ComputationBuilder::Reduce(
|
||||||
const LocalOp& operand, const LocalOp& init_value,
|
const LocalOp& operand, const LocalOp& init_value,
|
||||||
const LocalComputation& local_computation,
|
const Computation& local_computation,
|
||||||
absl::Span<const int64> dimensions_to_reduce) {
|
absl::Span<const int64> dimensions_to_reduce) {
|
||||||
return xla::Reduce(operand.op(), init_value.op(),
|
return xla::Reduce(operand.op(), init_value.op(),
|
||||||
local_computation.computation(), dimensions_to_reduce);
|
local_computation.computation(), dimensions_to_reduce);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
|
LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding(
|
||||||
const LocalOp& operand, const LocalOp& init_value,
|
const LocalOp& operand, const LocalOp& init_value,
|
||||||
const LocalComputation& local_computation,
|
const Computation& local_computation,
|
||||||
absl::Span<const int64> window_dimensions,
|
absl::Span<const int64> window_dimensions,
|
||||||
absl::Span<const int64> window_strides,
|
absl::Span<const int64> window_strides,
|
||||||
absl::Span<const int64> base_dilations,
|
absl::Span<const int64> base_dilations,
|
||||||
@ -875,51 +857,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
|
|||||||
padding);
|
padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
|
LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma,
|
||||||
const LocalOp& sigma,
|
const Shape& shape) {
|
||||||
const Shape& shape) {
|
|
||||||
return xla::RngNormal(mu.op(), sigma.op(), shape);
|
return xla::RngNormal(mu.op(), sigma.op(), shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
|
LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
return xla::RngUniform(a.op(), b.op(), shape);
|
return xla::RngUniform(a.op(), b.op(), shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::While(const LocalComputation& condition,
|
LocalOp ComputationBuilder::While(const Computation& condition,
|
||||||
const LocalComputation& body,
|
const Computation& body,
|
||||||
const LocalOp& init) {
|
const LocalOp& init) {
|
||||||
return xla::While(condition.computation(), body.computation(), init.op());
|
return xla::While(condition.computation(), body.computation(), init.op());
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Conditional(
|
LocalOp ComputationBuilder::Conditional(const LocalOp& predicate,
|
||||||
const LocalOp& predicate, const LocalOp& true_operand,
|
const LocalOp& true_operand,
|
||||||
const LocalComputation& true_computation, const LocalOp& false_operand,
|
const Computation& true_computation,
|
||||||
const LocalComputation& false_computation) {
|
const LocalOp& false_operand,
|
||||||
|
const Computation& false_computation) {
|
||||||
return xla::Conditional(predicate.op(), true_operand.op(),
|
return xla::Conditional(predicate.op(), true_operand.op(),
|
||||||
true_computation.computation(), false_operand.op(),
|
true_computation.computation(), false_operand.op(),
|
||||||
false_computation.computation());
|
false_computation.computation());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
|
StatusOr<bool> ComputationBuilder::IsConstant(const LocalOp& operand) {
|
||||||
return builder_.IsConstant(operand.op());
|
return builder_.IsConstant(operand.op());
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
|
LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
|
||||||
return xla::Sort(operand.op(), {}, dimension);
|
return xla::Sort(operand.op(), {}, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
|
LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
|
||||||
const LocalOp& values,
|
const LocalOp& values, int64 dimension) {
|
||||||
int64 dimension) {
|
|
||||||
return xla::Sort(keys.op(), {values.op()}, dimension);
|
return xla::Sort(keys.op(), {values.op()}, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) {
|
LocalOp ComputationBuilder::Cholesky(const LocalOp& a) {
|
||||||
return xla::Cholesky(a.op());
|
return xla::Cholesky(a.op());
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
|
LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
|
||||||
XlaBuilder* builder = a.op().builder();
|
XlaBuilder* builder = a.op().builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
|
TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
|
||||||
@ -927,17 +908,16 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a,
|
LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b,
|
||||||
const LocalOp& b,
|
bool left_side, bool lower,
|
||||||
bool left_side, bool lower,
|
bool unit_diagonal,
|
||||||
bool unit_diagonal,
|
int transpose_a) {
|
||||||
int transpose_a) {
|
|
||||||
return xla::TriangularSolve(
|
return xla::TriangularSolve(
|
||||||
a.op(), b.op(), left_side, lower, unit_diagonal,
|
a.op(), b.op(), left_side, lower, unit_diagonal,
|
||||||
xla::TriangularSolveOptions::Transpose(transpose_a));
|
xla::TriangularSolveOptions::Transpose(transpose_a));
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Gather(
|
LocalOp ComputationBuilder::Gather(
|
||||||
const LocalOp& input, const LocalOp& start_indices,
|
const LocalOp& input, const LocalOp& start_indices,
|
||||||
const GatherDimensionNumbers& dimension_numbers,
|
const GatherDimensionNumbers& dimension_numbers,
|
||||||
absl::Span<const int64> slice_sizes) {
|
absl::Span<const int64> slice_sizes) {
|
||||||
@ -945,24 +925,24 @@ LocalOp LocalComputationBuilder::Gather(
|
|||||||
slice_sizes);
|
slice_sizes);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalOp LocalComputationBuilder::Scatter(
|
LocalOp ComputationBuilder::Scatter(
|
||||||
const LocalOp& input, const LocalOp& scatter_indices,
|
const LocalOp& input, const LocalOp& scatter_indices,
|
||||||
const LocalOp& updates, const LocalComputation& update_computation,
|
const LocalOp& updates, const Computation& update_computation,
|
||||||
const ScatterDimensionNumbers& dimension_numbers) {
|
const ScatterDimensionNumbers& dimension_numbers) {
|
||||||
return xla::Scatter(input.op(), scatter_indices.op(), updates.op(),
|
return xla::Scatter(input.op(), scatter_indices.op(), updates.op(),
|
||||||
update_computation.computation(), dimension_numbers);
|
update_computation.computation(), dimension_numbers);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
|
StatusOr<Computation*> ComputationBuilder::BuildConstantSubGraph(
|
||||||
const LocalOp& operand) {
|
const LocalOp& operand) {
|
||||||
TF_ASSIGN_OR_RETURN(XlaComputation computation,
|
TF_ASSIGN_OR_RETURN(XlaComputation computation,
|
||||||
builder_.BuildConstantSubGraph(operand.op()));
|
builder_.BuildConstantSubGraph(operand.op()));
|
||||||
return new LocalComputation(std::move(computation));
|
return new Computation(std::move(computation));
|
||||||
}
|
}
|
||||||
|
|
||||||
#define _FORWARD(method_name, return_sig, args_sig, args) \
|
#define _FORWARD(method_name, return_sig, args_sig, args) \
|
||||||
return_sig LocalComputationBuilder::method_name args_sig { \
|
return_sig ComputationBuilder::method_name args_sig { \
|
||||||
return xla::method_name args; \
|
return xla::method_name args; \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define _FORWARD_UNOP(method_name) \
|
#define _FORWARD_UNOP(method_name) \
|
||||||
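The _FORWARD macro family keeps the dozens of element-wise forwarding methods down to one line each: _FORWARD stamps out a ComputationBuilder method body that simply delegates to the free function of the same name in namespace xla. As a sketch of the expansion (the exact argument list that _FORWARD_UNOP forwards is assumed here, not shown in this hunk):

// Hypothetical expansion of _FORWARD_UNOP(Abs), i.e. of
// _FORWARD(Abs, LocalOp, (const LocalOp& lhs), (lhs.op())):
LocalOp ComputationBuilder::Abs(const LocalOp& lhs) {
  return xla::Abs(lhs.op());
}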
@@ -1051,64 +1031,11 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {

 void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; }

-void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) {
-  delete computation;
-}
+void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; }

-void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) {
-  delete computation;
-}
+void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; }

-void DeleteLocalComputation(LocalComputation* computation) {
-  delete computation;
-}
+void DeleteComputation(Computation* computation) { delete computation; }

-StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
-    LocalShapedBuffer* local_shaped_buffer) {
-  const Shape tuple_shape = local_shaped_buffer->shape();
-
-  if (!tuple_shape.IsTuple()) {
-    return InvalidArgument(
-        "Attemped to destructure a LocalShapedBuffer that did not have a tuple "
-        "shape; shape: %s",
-        ShapeUtil::HumanString(tuple_shape));
-  }
-
-  DeviceMemoryAllocator* allocator =
-      local_shaped_buffer->shaped_buffer()->memory_allocator();
-  ShapedBuffer tuple_buffer = local_shaped_buffer->Release();
-
-  // Extract some metadata we use to construct scoped buffers.
-  const se::Platform* platform = tuple_buffer.platform();
-  int device_ordinal = tuple_buffer.device_ordinal();
-
-  ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
-  std::vector<LocalShapedBuffer*> results;
-  for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
-    // Create a shaped buffer for this destructured tuple element.
-    const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
-    VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
-    ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);
-
-    ShapeUtil::ForEachSubshape(
-        subshape, [&](const Shape& s, const ShapeIndex& index) {
-          ShapeIndex original(index);
-          original.push_front(i);
-          se::DeviceMemoryBase* device_memory =
-              shape_tree.mutable_element(original);
-          shaped_buffer.set_buffer(*device_memory, index);
-          *device_memory = se::DeviceMemoryBase();
-        });
-
-    VLOG(3) << "Completed tuple element: " << i;
-    results.push_back(new LocalShapedBuffer(
-        ScopedShapedBuffer(std::move(shaped_buffer), allocator)));
-  }
-  // Deallocate the root buffer.
-  se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
-  TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
-  return new LocalShapedBufferTuple(std::move(results));
-}
-
 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
     XrtAllocation* allocation, const string& session_target) {
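With the free function DestructureLocalShapedBufferTuple removed above, tuple destructuring moves onto LocalShapedBuffer itself (declared in the header change below). A sketch of the new call shape, assuming `buffer` is a hypothetical tuple-valued LocalShapedBuffer*:

// Sketch only; ownership of the returned tuple passes to the caller.
TF_ASSIGN_OR_RETURN(LocalShapedBufferTuple* elements,
                    buffer->DestructureTuple());
// The tuple exposes size() and Release(i) (both kept visible to SWIG in
// the interface file below) for handing out the element buffers.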
@@ -35,42 +35,42 @@ limitations under the License.
 namespace xla {
 namespace swig {

-// Initializes the number of replicas that XLA will be initialized with (when
-// first obtaining a handle to the local XLA service). If this is called after
-// the handle to the local XLA service has been established, then an error is
-// returned.
-Status InitializeReplicaCount(int replica_count);
-
-// Initializes the platform name that XLA will be initialized with (when
-// first obtaining a handle to the local XLA service). If this is called after
-// the handle to the local XLA service has been established, then an error is
-// returned.
-Status InitializePlatformName(const string& platform_name);
-
-// Returns the replica count that is currently set, regardless of whether the
-// local XLA service has been instantiated yet or not.
-int GetReplicaCount();
-
 // Registers a 'fn_capsule' as a CPU custom call target.
 // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
 // "xla._CPU_CUSTOM_CALL_TARGET".
 Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule);

-// Wraps the local client's infeed-transfer function.
-//
-// The default device ordinal (0) is used.
-Status TransferToInfeedLocal(const Literal& literal);
-
-// Transfers the given literal to the infeed of the given replica.
-//
-// The replica number is resolved to an appropriate device ordinal.
-Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
-
-// Transfers a literal of the given shape from the outfeed of the given replica.
-//
-// The replica number is resolved to an appropriate device ordinal.
-StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
-                                                  int replica_number);
+// Wrapper around an xla::LocalClient.
+class LocalClient {
+ public:
+  // Initializes a local XLA client for `platform_name`. Returns an error if no
+  // such platform exists, or if the platform has no visible devices.
+  static StatusOr<LocalClient> Get(const string& platform_name);
+
+  // Copyable and moveable; the class is just a wrapper around a
+  // xla::LocalClient pointer for convenient SWIG wrapping.
+
+  // Returns the number of devices known to the XLA client.
+  int DeviceCount() const;
+
+  // Wraps the local client's infeed-transfer function.
+  //
+  // The default device ordinal (0) is used.
+  Status TransferToInfeed(const Literal& literal, int device_ordinal);
+
+  // Transfers a literal of the given shape from the outfeed of the given
+  // replica.
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
+
+  xla::LocalClient* client() const { return client_; }
+
+ private:
+  LocalClient(xla::LocalClient* client);
+
+  xla::LocalClient* client_;
+};
+
+class LocalShapedBufferTuple;

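The LocalClient wrapper replaces the removed module-level replica/platform globals with an explicit per-platform handle. A minimal usage sketch under the declarations above (hypothetical; `literal` and `shape` are assumed to exist, and errors flow through the usual TF_* macros):

// Sketch only: acquire a client for a named platform and use it directly.
TF_ASSIGN_OR_RETURN(LocalClient client, LocalClient::Get("cpu"));
const int device_count = client.DeviceCount();
TF_RETURN_IF_ERROR(client.TransferToInfeed(literal, /*device_ordinal=*/0));
TF_ASSIGN_OR_RETURN(Literal out,
                    client.TransferFromOutfeed(shape, /*device_ordinal=*/0));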
 // Represents a reference to literals that live in a device-allocated buffer via
 // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
@@ -79,9 +79,9 @@ class LocalShapedBuffer {
  public:
   static StatusOr<LocalShapedBuffer*> FromLiteral(
       const Literal& argument, const absl::optional<Shape>& shape_with_layout,
-      int replica_number);
+      const LocalClient& client, int device_ordinal);

-  LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
+  LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client);
   StatusOr<Literal> ToLiteral() const;
   const Shape& shape() const;
   const ScopedShapedBuffer* shaped_buffer() const;
@@ -90,8 +90,13 @@ class LocalShapedBuffer {
   // analogous to std::unique_ptr::release().
   ShapedBuffer Release();

+  // Destructures a tuple-valued LocalShapedBuffer into its constituent
+  // elements in LocalShapedBufferTuple form.
+  StatusOr<LocalShapedBufferTuple*> DestructureTuple();
+
  private:
   ScopedShapedBuffer shaped_buffer_;
+  xla::LocalClient* client_;
 };

 // Result of a tuple destructuring operation on a LocalShapedBuffer -- this
@@ -117,11 +122,6 @@ class LocalShapedBufferTuple {
   std::vector<LocalShapedBuffer*> elements_;
 };

-// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements
-// in LocalShapedBufferTuple form.
-StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
-    LocalShapedBuffer* local_shaped_buffer);
-
 // Represents a reference to literals that live in a device-allocated buffer via
 // XRT. Specifically, wraps an int64 handle produced by running the allocation
 // graph, and an XLA shape to track the referent's shape.
@@ -176,14 +176,19 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(

 // Represents a compiled computation that can be executed given handles to
 // device-allocated literals. Specifically, wraps an XLA LocalExecutable.
-class CompiledLocalComputation {
+class LocalExecutable {
  public:
-  CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
+  LocalExecutable(std::unique_ptr<xla::LocalExecutable> executable,
+                  xla::DeviceAssignment device_assignment,
+                  xla::LocalClient* client);

   int num_replicas() const {
     return executable_->build_options().num_replicas();
   }

+  // Returns the device ordinals to which each replica is assigned.
+  std::vector<int> DeviceOrdinals() const;
+
   StatusOr<LocalShapedBuffer*> Execute(
       absl::Span<LocalShapedBuffer* const> argument_handles);

@@ -194,18 +199,22 @@ class CompiledLocalComputation {
       absl::Span<const std::vector<LocalShapedBuffer*> > argument_handles);

  private:
-  std::unique_ptr<LocalExecutable> executable_;
+  const std::unique_ptr<xla::LocalExecutable> executable_;
+  const xla::DeviceAssignment device_assignment_;
+  xla::LocalClient* const client_;
 };

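LocalExecutable now carries the DeviceAssignment chosen at compile time, and DeviceOrdinals() reads the per-replica device ordinals back out of it. A plausible sketch of that accessor (the real definition lives in the .cc file, not in this hunk), assuming one computation per replica:

// Sketch only: map each replica to its assigned device ordinal.
std::vector<int> LocalExecutable::DeviceOrdinals() const {
  std::vector<int> device_ordinals;
  device_ordinals.reserve(device_assignment_.replica_count());
  for (int replica = 0; replica < device_assignment_.replica_count();
       ++replica) {
    device_ordinals.push_back(device_assignment_(replica, /*computation=*/0));
  }
  return device_ordinals;
}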
 // Represents a compiled computation that can be executed given handles to
 // device-allocated literals. Specifically, wraps an XRT computation handle.
-class CompiledXrtComputation {
+class XrtExecutable {
  public:
   // Accepts a `session_target` argument, used in constructing the
   // `tensorflow::ClientSession` instance in which the execution graph is run.
-  CompiledXrtComputation(const ProgramShape& program_shape, int64 handle,
-                         const string& session_target);
-  ~CompiledXrtComputation();
+  XrtExecutable(const ProgramShape& program_shape, int64 handle,
+                const string& session_target);
+  ~XrtExecutable();

+  std::vector<int> DeviceOrdinals() const { return {0}; }
+
   StatusOr<XrtAllocation*> Execute(
       absl::Span<XrtAllocation* const> argument_handles);

@@ -219,21 +228,21 @@ class CompiledXrtComputation {
   const string session_target_;
 };

-// Wraps a XlaComputation produced by a LocalComputationBuilder. The
+// Wraps a XlaComputation produced by a ComputationBuilder. The
 // Compile method compiles the computation to a (local) executable via
 // the client library's local client. This class is intended to be
 // made available to Python via SWIG.
-class LocalComputation {
+class Computation {
  public:
-  LocalComputation(XlaComputation computation);
+  Computation(XlaComputation computation);

-  StatusOr<CompiledLocalComputation*> Compile(
+  StatusOr<LocalExecutable*> Compile(
       const std::vector<Shape>& argument_shapes,
-      const ExecutableBuildOptions* build_options);
+      const ExecutableBuildOptions* build_options, const LocalClient& client);

   // Accepts a `session_target` argument, used in constructing the
   // `tensorflow::ClientSession` instance in which the compilation graph is run.
-  StatusOr<CompiledXrtComputation*> CompileForXrt(
+  StatusOr<XrtExecutable*> CompileForXrt(
       const std::vector<Shape>& argument_shapes, const string& session_target);

   const XlaComputation& computation() const;
@@ -243,6 +252,9 @@ class LocalComputation {
   // string on failure.
   string GetSerializedProto() const;

+  // Returns the program shape for this computation.
+  StatusOr<ProgramShape> GetProgramShape() const;
+
   // Returns the return-value shape for this computation.
   StatusOr<Shape> GetReturnValueShape() const;

@@ -250,7 +262,7 @@ class LocalComputation {
   XlaComputation computation_;
 };

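Taken together, the renamed classes compose into the client workflow sketched below (hypothetical code, not part of this change; ShapeUtil and LiteralUtil are the standard XLA helpers, and the raw result pointers are owned by the caller and freed through the Delete* helpers declared further down):

// Sketch only: build, compile, and run a trivial computation.
StatusOr<Literal> RunIdentityExample() {
  TF_ASSIGN_OR_RETURN(LocalClient client, LocalClient::Get("cpu"));

  ComputationBuilder builder("identity");
  const Shape shape = ShapeUtil::MakeShape(F32, {2});
  LocalOp x = builder.Parameter(0, shape, "x");
  TF_ASSIGN_OR_RETURN(Computation* computation, builder.BuildWithRoot(x));

  ExecutableBuildOptions options;
  TF_ASSIGN_OR_RETURN(LocalExecutable* executable,
                      computation->Compile({shape}, &options, client));

  TF_ASSIGN_OR_RETURN(
      LocalShapedBuffer* arg,
      LocalShapedBuffer::FromLiteral(LiteralUtil::CreateR1<float>({1.f, 2.f}),
                                     /*shape_with_layout=*/absl::nullopt,
                                     client, /*device_ordinal=*/0));
  TF_ASSIGN_OR_RETURN(LocalShapedBuffer* result, executable->Execute({arg}));
  return result->ToLiteral();
}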
-// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended
+// Wraps a XlaOp produced by a ComputationBuilder. This class is intended
 // to be made available to Python via SWIG.
 class LocalOp {
  public:
@@ -267,20 +279,20 @@ class LocalOp {
 //   Python.
 // - Set up the underlying builder to use the client library's
 //   LocalClient.
-// - Wrap Computations in LocalComputations for Python access.
-// - Correspondingly unwrap incoming LocalComputations.
-class LocalComputationBuilder {
+// - Wrap Computations in Computations for Python access.
+// - Correspondingly unwrap incoming Computations.
+class ComputationBuilder {
  public:
-  LocalComputationBuilder(const string& computation_name);
+  ComputationBuilder(const string& computation_name);

   void SetOpMetadata(const OpMetadata& metadata);
   void ClearOpMetadata();

-  // Returns an owned LocalComputation to the caller on success.
-  StatusOr<LocalComputation*> Build();
+  // Returns an owned Computation to the caller on success.
+  StatusOr<Computation*> Build();

-  // Returns an owned LocalComputation to the caller on success with given root.
-  StatusOr<LocalComputation*> BuildWithRoot(const LocalOp& root);
+  // Returns an owned Computation to the caller on success with given root.
+  StatusOr<Computation*> BuildWithRoot(const LocalOp& root);

   LocalOp Parameter(int64 parameter_number, const Shape& shape,
                     const string& name);
@@ -339,11 +351,11 @@ class LocalComputationBuilder {
   LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);

   LocalOp SelectAndScatterWithGeneralPadding(
-      const LocalOp& operand, const LocalComputation& select,
+      const LocalOp& operand, const Computation& select,
       absl::Span<const int64> window_dimensions,
       absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
-      const LocalOp& init_value, const LocalComputation& scatter);
+      const LocalOp& init_value, const Computation& scatter);

   LocalOp Tuple(absl::Span<const LocalOp> elements);

@@ -369,7 +381,7 @@ class LocalComputationBuilder {
   LocalOp BitcastConvertType(const LocalOp& operand,
                              PrimitiveType new_element_type);

-  LocalOp Call(const LocalComputation& local_computation,
+  LocalOp Call(const Computation& local_computation,
                absl::Span<const LocalOp> operands);

   LocalOp CustomCall(const string& call_target_name,
@@ -384,16 +396,16 @@ class LocalComputationBuilder {
   LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);

   LocalOp Map(absl::Span<const LocalOp> operands,
-              const LocalComputation& local_computation,
+              const Computation& local_computation,
               absl::Span<const int64> dimensions);

   LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
-                 const LocalComputation& local_computation,
+                 const Computation& local_computation,
                  absl::Span<const int64> dimensions_to_reduce);

   LocalOp ReduceWindowWithGeneralPadding(
       const LocalOp& operand, const LocalOp& init_value,
-      const LocalComputation& local_computation,
+      const Computation& local_computation,
       absl::Span<const int64> window_dimensions,
       absl::Span<const int64> window_strides,
       absl::Span<const int64> base_dilations,
@@ -405,13 +417,13 @@ class LocalComputationBuilder {

   LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape);

-  LocalOp While(const LocalComputation& condition, const LocalComputation& body,
+  LocalOp While(const Computation& condition, const Computation& body,
                 const LocalOp& init);

   LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand,
-                      const LocalComputation& true_computation,
+                      const Computation& true_computation,
                       const LocalOp& false_operand,
-                      const LocalComputation& false_computation);
+                      const Computation& false_computation);

   StatusOr<bool> IsConstant(const LocalOp& operand);

@@ -435,11 +447,10 @@ class LocalComputationBuilder {
                  absl::Span<const int64> slice_sizes);

   LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices,
-                  const LocalOp& updates,
-                  const LocalComputation& update_computation,
+                  const LocalOp& updates, const Computation& update_computation,
                   const ScatterDimensionNumbers& dimension_numbers);

-  StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
+  StatusOr<Computation*> BuildConstantSubGraph(const LocalOp& operand);

 #define _FORWARD(method_name, return_sig, args_sig) \
   return_sig method_name args_sig;
@@ -529,9 +540,9 @@ class LocalComputationBuilder {
 // Functions for freeing resources from the Python side.
 void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer);
 void DeleteXrtAllocation(XrtAllocation* allocation);
-void DeleteCompiledLocalComputation(CompiledLocalComputation* computation);
-void DeleteCompiledXrtComputation(CompiledXrtComputation* computation);
-void DeleteLocalComputation(LocalComputation* computation);
+void DeleteLocalExecutable(LocalExecutable* computation);
+void DeleteXrtExecutable(XrtExecutable* computation);
+void DeleteComputation(Computation* computation);

 }  // namespace swig
 }  // namespace xla

@@ -23,11 +23,13 @@ limitations under the License.
 //  C++                                  Python
 // -------------------------------------+---------------------------------------
 //  Span<int64>                        <-  sequence of int
+//  vector<int>                        ->  sequence of int
 //  Span<LocalOp>                      <-  sequence of LocalOp
 //  Literal                            <-> (nested tuple of) numpy ndarray
 //  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
 //  Shape                              ->  pair holding (dtype, dimensions)
 //                                     <-  object duck-typed as xla_client.Shape
+//  ProgramShape                       ->  pair of ([arg_shapes], ret_shape)
 //  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
 //  PrimitiveType                      <-  int
 //  Span<pair<int64, int64>>           <-  sequence of int pairs
@@ -97,7 +99,7 @@ limitations under the License.
 // wrapped in a Python class (xla_client.Shape) so as not to expose
 // the raw pair externally.
 //
-// Other SWIG object wrappers (e.g. of LocalComputation) are further
+// Other SWIG object wrappers (e.g. of Computation) are further
 // wrapped by xla_client in order to set up a custom destructor that
 // triggers memory deallocation on the C++ side.

@@ -214,6 +216,15 @@ tensorflow::ImportNumpy();

 // Basic types

+%typemap(out) std::vector<int> {
+  PyObject* out = PyList_New($1.size());
+  for (int i = 0; i < $1.size(); ++i) {
+    PyList_SET_ITEM(out, i, PyInt_FromLong($1[i]));
+  }
+  $result = out;
+}
+
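The new std::vector<int> out-typemap exists so that the DeviceOrdinals() accessors introduced in the header can surface in Python as a plain list of ints; the C++ side being marshaled is simply (declaration repeated from the header change above, for context):

std::vector<int> DeviceOrdinals() const;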
 %typemap(out) StatusOr<bool> {
   if ($1.ok()) {
     $result = PyBool_FromLong($1.ConsumeValueOrDie());
@@ -287,12 +298,12 @@ tensorflow::ImportNumpy();

 // Computation and buffer/allocation types

-%typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> {
+%typemap(out) StatusOr<xla::swig::LocalClient> {
   if ($1.ok()) {
-    auto* value = $1.ValueOrDie();
+    xla::swig::LocalClient value = $1.ValueOrDie();
     {
-      auto* $1 = value;
-      $typemap(out, xla::swig::CompiledLocalComputation*)
+      auto $1 = value;
+      $typemap(out, xla::swig::LocalClient)
     }
   } else {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -300,12 +311,25 @@ tensorflow::ImportNumpy();
   }
 }

-%typemap(out) StatusOr<xla::swig::CompiledXrtComputation*> {
+%typemap(out) StatusOr<xla::swig::LocalExecutable*> {
   if ($1.ok()) {
     auto* value = $1.ValueOrDie();
     {
       auto* $1 = value;
-      $typemap(out, xla::swig::CompiledXrtComputation*)
+      $typemap(out, xla::swig::LocalExecutable*)
+    }
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    SWIG_fail;
+  }
+}
+
+%typemap(out) StatusOr<xla::swig::XrtExecutable*> {
+  if ($1.ok()) {
+    auto* value = $1.ValueOrDie();
+    {
+      auto* $1 = value;
+      $typemap(out, xla::swig::XrtExecutable*)
     }
   } else {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -365,12 +389,12 @@ tensorflow::ImportNumpy();
   }
 }

-%typemap(out) StatusOr<xla::swig::LocalComputation*> {
+%typemap(out) StatusOr<xla::swig::Computation*> {
   if ($1.ok()) {
     auto* value = $1.ValueOrDie();
     {
       auto* $1 = value;
-      $typemap(out, xla::swig::LocalComputation*)
+      $typemap(out, xla::swig::Computation*)
     }
   } else {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -519,18 +543,30 @@ tensorflow::ImportNumpy();
 // Shape

 %typemap(out) const Shape& {
-  $result = numpy::PyShapeInfoFromXlaShape(*$1);
+  $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
 }

 %typemap(out) StatusOr<Shape> {
   if ($1.ok()) {
-    $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
+    $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release();
   } else {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
     SWIG_fail;
   }
 }

+%typemap(out) StatusOr<ProgramShape> {
+  if ($1.ok()) {
+    $result = numpy::PyProgramShapeInfoFromXlaProgramShape(
+        $1.ConsumeValueOrDie()).release();
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    SWIG_fail;
+  }
+}
+
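The .release() calls added above (and the new StatusOr<ProgramShape> typemap) suggest that the numpy-bridge helpers now return an owning smart-pointer type rather than a raw PyObject*, presumably along these lines (assumed signatures; the actual declarations live in numpy_bridge.h and are not part of this hunk):

// Assumed shape of the helpers whose results are now released to SWIG:
Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape);
Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
    const ProgramShape& program_shape);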
 %typemap(in) const Shape& (Shape temp) {
   StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
   if (!statusor.ok()) {
@@ -558,7 +594,7 @@ tensorflow::ImportNumpy();
 }

 %typemap(out) std::unique_ptr<Shape> {
-  $result = numpy::PyShapeInfoFromXlaShape(*$1);
+  $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
 }

 %typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
@@ -966,17 +1002,17 @@ tensorflow::ImportNumpy();
 %ignoreall
 %unignore xla;
 %unignore xla::swig;
-%unignore xla::swig::InitializeReplicaCount;
-%unignore xla::swig::InitializePlatformName;
-%unignore xla::swig::GetReplicaCount;
 %unignore xla::swig::RegisterCpuCustomCallTarget;
-%unignore xla::swig::TransferToInfeedLocal;
-%unignore xla::swig::TransferToInfeedLocalReplica;
-%unignore xla::swig::TransferFromOutfeedLocalReplica;
+%unignore xla::swig::LocalClient;
+%unignore xla::swig::LocalClient::Get;
+%unignore xla::swig::LocalClient::DeviceCount;
+%unignore xla::swig::LocalClient::TransferToInfeed;
+%unignore xla::swig::LocalClient::TransferFromOutfeed;
 %unignore xla::swig::LocalShapedBuffer;
 %unignore xla::swig::LocalShapedBuffer::FromLiteral;
 %unignore xla::swig::LocalShapedBuffer::ToLiteral;
 %unignore xla::swig::LocalShapedBuffer::shape;
+%unignore xla::swig::LocalShapedBuffer::DestructureTuple;
 %unignore xla::swig::LocalShapedBufferTuple;
 %unignore xla::swig::LocalShapedBufferTuple::Release;
 %unignore xla::swig::LocalShapedBufferTuple::size;
@@ -987,139 +1023,141 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::XrtAllocationTuple;
 %unignore xla::swig::XrtAllocationTuple::Release;
 %unignore xla::swig::XrtAllocationTuple::size;
-%unignore xla::swig::CompiledLocalComputation;
-%unignore xla::swig::CompiledLocalComputation::Execute;
-%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica;
-%unignore xla::swig::CompiledXrtComputation;
-%unignore xla::swig::CompiledXrtComputation::Execute;
-%unignore xla::swig::LocalComputation;
-%unignore xla::swig::LocalComputation::Compile;
-%unignore xla::swig::LocalComputation::CompileForXrt;
-%unignore xla::swig::LocalComputation::GetReturnValueShape;
-%unignore xla::swig::LocalComputation::GetSerializedProto;
+%unignore xla::swig::LocalExecutable;
+%unignore xla::swig::LocalExecutable::DeviceOrdinals;
+%unignore xla::swig::LocalExecutable::Execute;
+%unignore xla::swig::LocalExecutable::ExecutePerReplica;
+%unignore xla::swig::XrtExecutable;
+%unignore xla::swig::XrtExecutable::DeviceOrdinals;
+%unignore xla::swig::XrtExecutable::Execute;
+%unignore xla::swig::Computation;
+%unignore xla::swig::Computation::Compile;
+%unignore xla::swig::Computation::CompileForXrt;
+%unignore xla::swig::Computation::GetProgramShape;
+%unignore xla::swig::Computation::GetReturnValueShape;
+%unignore xla::swig::Computation::GetSerializedProto;
 %unignore xla::swig::LocalOp;
-%unignore xla::swig::LocalComputationBuilder;
+%unignore xla::swig::ComputationBuilder;
-%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
+%unignore xla::swig::ComputationBuilder::ComputationBuilder;
-%unignore xla::swig::LocalComputationBuilder::Build;
+%unignore xla::swig::ComputationBuilder::Build;
-%unignore xla::swig::LocalComputationBuilder::BuildWithRoot;
+%unignore xla::swig::ComputationBuilder::BuildWithRoot;
-%unignore xla::swig::LocalComputationBuilder::SetOpMetadata;
+%unignore xla::swig::ComputationBuilder::SetOpMetadata;
-%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
+%unignore xla::swig::ComputationBuilder::ClearOpMetadata;
-%unignore xla::swig::LocalComputationBuilder::Parameter;
+%unignore xla::swig::ComputationBuilder::Parameter;
-%unignore xla::swig::LocalComputationBuilder::GetShape;
+%unignore xla::swig::ComputationBuilder::GetShape;
-%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
+%unignore xla::swig::ComputationBuilder::GetReturnValueShape;
-%unignore xla::swig::LocalComputationBuilder::Infeed;
+%unignore xla::swig::ComputationBuilder::Infeed;
-%unignore xla::swig::LocalComputationBuilder::Outfeed;
+%unignore xla::swig::ComputationBuilder::Outfeed;
-%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
+%unignore xla::swig::ComputationBuilder::ConstantLiteral;
-%unignore xla::swig::LocalComputationBuilder::ConstantR0;
+%unignore xla::swig::ComputationBuilder::ConstantR0;
-%unignore xla::swig::LocalComputationBuilder::Iota;
+%unignore xla::swig::ComputationBuilder::Iota;
-%unignore xla::swig::LocalComputationBuilder::BroadcastedIota;
+%unignore xla::swig::ComputationBuilder::BroadcastedIota;
-%unignore xla::swig::LocalComputationBuilder::Broadcast;
+%unignore xla::swig::ComputationBuilder::Broadcast;
-%unignore xla::swig::LocalComputationBuilder::BroadcastInDim;
+%unignore xla::swig::ComputationBuilder::BroadcastInDim;
-%unignore xla::swig::LocalComputationBuilder::Pad;
+%unignore xla::swig::ComputationBuilder::Pad;
-%unignore xla::swig::LocalComputationBuilder::Reshape;
+%unignore xla::swig::ComputationBuilder::Reshape;
-%unignore xla::swig::LocalComputationBuilder::Collapse;
+%unignore xla::swig::ComputationBuilder::Collapse;
-%unignore xla::swig::LocalComputationBuilder::AllToAll;
+%unignore xla::swig::ComputationBuilder::AllToAll;
-%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum;
+%unignore xla::swig::ComputationBuilder::CrossReplicaSum;
-%unignore xla::swig::LocalComputationBuilder::Slice;
+%unignore xla::swig::ComputationBuilder::Slice;
-%unignore xla::swig::LocalComputationBuilder::SliceInDim;
+%unignore xla::swig::ComputationBuilder::SliceInDim;
-%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
+%unignore xla::swig::ComputationBuilder::DynamicSlice;
-%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
+%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice;
-%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
+%unignore xla::swig::ComputationBuilder::ConcatInDim;
-%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding;
+%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding;
-%unignore xla::swig::LocalComputationBuilder::Select;
+%unignore xla::swig::ComputationBuilder::Select;
-%unignore xla::swig::LocalComputationBuilder::Tuple;
+%unignore xla::swig::ComputationBuilder::Tuple;
-%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
+%unignore xla::swig::ComputationBuilder::GetTupleElement;
-%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
+%unignore xla::swig::ComputationBuilder::ConvertElementType;
-%unignore xla::swig::LocalComputationBuilder::BitcastConvertType;
+%unignore xla::swig::ComputationBuilder::BitcastConvertType;
-%unignore xla::swig::LocalComputationBuilder::Call;
+%unignore xla::swig::ComputationBuilder::Call;
-%unignore xla::swig::LocalComputationBuilder::Transpose;
+%unignore xla::swig::ComputationBuilder::Transpose;
-%unignore xla::swig::LocalComputationBuilder::Rev;
+%unignore xla::swig::ComputationBuilder::Rev;
-%unignore xla::swig::LocalComputationBuilder::Clamp;
+%unignore xla::swig::ComputationBuilder::Clamp;
-%unignore xla::swig::LocalComputationBuilder::Map;
+%unignore xla::swig::ComputationBuilder::Map;
-%unignore xla::swig::LocalComputationBuilder::Reduce;
+%unignore xla::swig::ComputationBuilder::Reduce;
-%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
+%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding;
-%unignore xla::swig::LocalComputationBuilder::RngNormal;
+%unignore xla::swig::ComputationBuilder::RngNormal;
-%unignore xla::swig::LocalComputationBuilder::RngUniform;
+%unignore xla::swig::ComputationBuilder::RngUniform;
-%unignore xla::swig::LocalComputationBuilder::RngBernoulli;
+%unignore xla::swig::ComputationBuilder::RngBernoulli;
-%unignore xla::swig::LocalComputationBuilder::While;
+%unignore xla::swig::ComputationBuilder::While;
-%unignore xla::swig::LocalComputationBuilder::Conditional;
+%unignore xla::swig::ComputationBuilder::Conditional;
-%unignore xla::swig::LocalComputationBuilder::IsConstant;
+%unignore xla::swig::ComputationBuilder::IsConstant;
-%unignore xla::swig::LocalComputationBuilder::Eq;
+%unignore xla::swig::ComputationBuilder::Eq;
-%unignore xla::swig::LocalComputationBuilder::Ne;
+%unignore xla::swig::ComputationBuilder::Ne;
-%unignore xla::swig::LocalComputationBuilder::Ge;
+%unignore xla::swig::ComputationBuilder::Ge;
-%unignore xla::swig::LocalComputationBuilder::Gt;
+%unignore xla::swig::ComputationBuilder::Gt;
-%unignore xla::swig::LocalComputationBuilder::Lt;
+%unignore xla::swig::ComputationBuilder::Lt;
-%unignore xla::swig::LocalComputationBuilder::Le;
+%unignore xla::swig::ComputationBuilder::Le;
-%unignore xla::swig::LocalComputationBuilder::Dot;
+%unignore xla::swig::ComputationBuilder::Dot;
-%unignore xla::swig::LocalComputationBuilder::DotGeneral;
+%unignore xla::swig::ComputationBuilder::DotGeneral;
-%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
+%unignore xla::swig::ComputationBuilder::ConvGeneralDilated;
-%unignore xla::swig::LocalComputationBuilder::Add;
+%unignore xla::swig::ComputationBuilder::Add;
-%unignore xla::swig::LocalComputationBuilder::Sub;
+%unignore xla::swig::ComputationBuilder::Sub;
-%unignore xla::swig::LocalComputationBuilder::Mul;
+%unignore xla::swig::ComputationBuilder::Mul;
-%unignore xla::swig::LocalComputationBuilder::Div;
+%unignore xla::swig::ComputationBuilder::Div;
-%unignore xla::swig::LocalComputationBuilder::Rem;
+%unignore xla::swig::ComputationBuilder::Rem;
-%unignore xla::swig::LocalComputationBuilder::Max;
+%unignore xla::swig::ComputationBuilder::Max;
-%unignore xla::swig::LocalComputationBuilder::Min;
+%unignore xla::swig::ComputationBuilder::Min;
-%unignore xla::swig::LocalComputationBuilder::And;
+%unignore xla::swig::ComputationBuilder::And;
-%unignore xla::swig::LocalComputationBuilder::Or;
+%unignore xla::swig::ComputationBuilder::Or;
-%unignore xla::swig::LocalComputationBuilder::Xor;
+%unignore xla::swig::ComputationBuilder::Xor;
-%unignore xla::swig::LocalComputationBuilder::ShiftLeft;
+%unignore xla::swig::ComputationBuilder::ShiftLeft;
-%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic;
+%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic;
-%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical;
+%unignore xla::swig::ComputationBuilder::ShiftRightLogical;
-%unignore xla::swig::LocalComputationBuilder::Not;
+%unignore xla::swig::ComputationBuilder::Not;
-%unignore xla::swig::LocalComputationBuilder::Abs;
+%unignore xla::swig::ComputationBuilder::Abs;
-%unignore xla::swig::LocalComputationBuilder::Exp;
+%unignore xla::swig::ComputationBuilder::Exp;
-%unignore xla::swig::LocalComputationBuilder::Expm1;
+%unignore xla::swig::ComputationBuilder::Expm1;
-%unignore xla::swig::LocalComputationBuilder::Floor;
+%unignore xla::swig::ComputationBuilder::Floor;
-%unignore xla::swig::LocalComputationBuilder::Ceil;
+%unignore xla::swig::ComputationBuilder::Ceil;
-%unignore xla::swig::LocalComputationBuilder::Round;
+%unignore xla::swig::ComputationBuilder::Round;
-%unignore xla::swig::LocalComputationBuilder::Log;
+%unignore xla::swig::ComputationBuilder::Log;
-%unignore xla::swig::LocalComputationBuilder::Log1p;
+%unignore xla::swig::ComputationBuilder::Log1p;
-%unignore xla::swig::LocalComputationBuilder::Sign;
+%unignore xla::swig::ComputationBuilder::Sign;
-%unignore xla::swig::LocalComputationBuilder::Cos;
+%unignore xla::swig::ComputationBuilder::Cos;
-%unignore xla::swig::LocalComputationBuilder::Sin;
+%unignore xla::swig::ComputationBuilder::Sin;
-%unignore xla::swig::LocalComputationBuilder::Tanh;
+%unignore xla::swig::ComputationBuilder::Tanh;
-%unignore xla::swig::LocalComputationBuilder::Atan2;
+%unignore xla::swig::ComputationBuilder::Atan2;
-%unignore xla::swig::LocalComputationBuilder::IsFinite;
+%unignore xla::swig::ComputationBuilder::IsFinite;
-%unignore xla::swig::LocalComputationBuilder::Pow;
+%unignore xla::swig::ComputationBuilder::Pow;
-%unignore xla::swig::LocalComputationBuilder::Neg;
+%unignore xla::swig::ComputationBuilder::Neg;
-%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::ComputationBuilder::Sort;
-%unignore xla::swig::LocalComputationBuilder::SortKeyVal;
+%unignore xla::swig::ComputationBuilder::SortKeyVal;
-%unignore xla::swig::LocalComputationBuilder::Sqrt;
+%unignore xla::swig::ComputationBuilder::Sqrt;
-%unignore xla::swig::LocalComputationBuilder::Rsqrt;
+%unignore xla::swig::ComputationBuilder::Rsqrt;
-%unignore xla::swig::LocalComputationBuilder::Square;
+%unignore xla::swig::ComputationBuilder::Square;
-%unignore xla::swig::LocalComputationBuilder::Reciprocal;
+%unignore xla::swig::ComputationBuilder::Reciprocal;
-%unignore xla::swig::LocalComputationBuilder::Erfc;
+%unignore xla::swig::ComputationBuilder::Erfc;
-%unignore xla::swig::LocalComputationBuilder::Erf;
+%unignore xla::swig::ComputationBuilder::Erf;
-%unignore xla::swig::LocalComputationBuilder::ErfInv;
+%unignore xla::swig::ComputationBuilder::ErfInv;
-%unignore xla::swig::LocalComputationBuilder::Lgamma;
+%unignore xla::swig::ComputationBuilder::Lgamma;
-%unignore xla::swig::LocalComputationBuilder::Digamma;
+%unignore xla::swig::ComputationBuilder::Digamma;
-%unignore xla::swig::LocalComputationBuilder::Acos;
+%unignore xla::swig::ComputationBuilder::Acos;
-%unignore xla::swig::LocalComputationBuilder::Asin;
+%unignore xla::swig::ComputationBuilder::Asin;
-%unignore xla::swig::LocalComputationBuilder::Atan;
+%unignore xla::swig::ComputationBuilder::Atan;
-%unignore xla::swig::LocalComputationBuilder::Tan;
+%unignore xla::swig::ComputationBuilder::Tan;
-%unignore xla::swig::LocalComputationBuilder::Acosh;
+%unignore xla::swig::ComputationBuilder::Acosh;
-%unignore xla::swig::LocalComputationBuilder::Asinh;
+%unignore xla::swig::ComputationBuilder::Asinh;
-%unignore xla::swig::LocalComputationBuilder::Atanh;
+%unignore xla::swig::ComputationBuilder::Atanh;
-%unignore xla::swig::LocalComputationBuilder::Cosh;
+%unignore xla::swig::ComputationBuilder::Cosh;
-%unignore xla::swig::LocalComputationBuilder::Sinh;
+%unignore xla::swig::ComputationBuilder::Sinh;
-%unignore xla::swig::LocalComputationBuilder::Real;
+%unignore xla::swig::ComputationBuilder::Real;
-%unignore xla::swig::LocalComputationBuilder::Imag;
+%unignore xla::swig::ComputationBuilder::Imag;
-%unignore xla::swig::LocalComputationBuilder::Conj;
+%unignore xla::swig::ComputationBuilder::Conj;
-%unignore xla::swig::LocalComputationBuilder::Complex;
+%unignore xla::swig::ComputationBuilder::Complex;
-%unignore xla::swig::LocalComputationBuilder::Cholesky;
+%unignore xla::swig::ComputationBuilder::Cholesky;
-%unignore xla::swig::LocalComputationBuilder::QR;
+%unignore xla::swig::ComputationBuilder::QR;
-%unignore xla::swig::LocalComputationBuilder::TriangularSolve;
+%unignore xla::swig::ComputationBuilder::TriangularSolve;
|
||||||
%unignore xla::swig::LocalComputationBuilder::CustomCall;
|
%unignore xla::swig::ComputationBuilder::CustomCall;
|
||||||
%unignore xla::swig::LocalComputationBuilder::Gather;
|
%unignore xla::swig::ComputationBuilder::Gather;
|
||||||
%unignore xla::swig::LocalComputationBuilder::Scatter;
|
%unignore xla::swig::ComputationBuilder::Scatter;
|
||||||
%unignore xla::swig::DeleteLocalComputation;
|
%unignore xla::swig::DeleteComputation;
|
||||||
%unignore xla::swig::DestructureLocalShapedBufferTuple;
|
|
||||||
%unignore xla::swig::DestructureXrtAllocationTuple;
|
%unignore xla::swig::DestructureXrtAllocationTuple;
|
||||||
%unignore xla::swig::DeleteLocalShapedBuffer;
|
%unignore xla::swig::DeleteLocalShapedBuffer;
|
||||||
%unignore xla::swig::DeleteXrtAllocation;
|
%unignore xla::swig::DeleteXrtAllocation;
|
||||||
%unignore xla::swig::DeleteCompiledLocalComputation;
|
%unignore xla::swig::DeleteLocalExecutable;
|
||||||
%unignore xla::swig::DeleteCompiledXrtComputation;
|
%unignore xla::swig::DeleteXrtExecutable;
|
||||||
|
|
||||||
%thread;
|
%thread;
|
||||||
%include "tensorflow/compiler/xla/python/local_computation_builder.h"
|
%include "tensorflow/compiler/xla/python/local_computation_builder.h"
|
||||||
|
@@ -127,28 +127,42 @@ bool NumpyTypeIsValid(int np_type) {
   }
 }
 
-PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
+Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) {
   int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
   PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
 
-  PyObject* dimensions;
+  Safe_PyObjectPtr dimensions;
   if (shape.IsTuple()) {
     int num_elements = ShapeUtil::TupleElementCount(shape);
-    dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape));
+    dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape)));
     for (int i = 0; i < num_elements; ++i) {
       PyTuple_SET_ITEM(
-          dimensions, i,
-          PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
+          dimensions.get(), i,
+          PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))
+              .release());
     }
   } else {
     int rank = shape.rank();
-    dimensions = PyTuple_New(rank);
+    dimensions = make_safe(PyTuple_New(rank));
     for (int i = 0; i < rank; ++i) {
-      PyTuple_SET_ITEM(dimensions, i,
+      PyTuple_SET_ITEM(dimensions.get(), i,
                        LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
     }
   }
-  return PyTuple_Pack(2, np_dtype, dimensions);
+  return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release()));
+}
+
+Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
+    const ProgramShape& shape) {
+  Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size()));
+  for (int i = 0; i < shape.parameters_size(); ++i) {
+    PyTuple_SET_ITEM(arg_shapes.get(), i,
+                     PyShapeInfoFromXlaShape(shape.parameters(i)).release());
+  }
+
+  Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result());
+  return make_safe(
+      PyTuple_Pack(2, arg_shapes.release(), result_shape.release()));
 }
 
 // Precondition: o->ob_type == &PyArrayDescr_Type
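Both helpers above hand shapes back to Python as nested (dtype, dimensions) pairs, now owned by Safe_PyObjectPtr so references cannot leak on early exits. A pure-Python sketch of the encoding they produce (the shape_info helper and the object-dtype convention for tuple shapes are illustrative assumptions mirroring xla_client's dtype table, not part of this patch):

    import numpy as np

    def shape_info(dtype, dims):
        # Array shapes are encoded as (np_dtype, dimension_tuple); tuple
        # shapes nest recursively under an object-dtype outer pair.
        return (np.dtype(dtype), tuple(dims))

    array_info = shape_info(np.float32, (2, 3))          # an f32[2,3] shape
    tuple_info = (np.dtype(object),
                  (array_info, shape_info(np.int32, (4,))))
    # A program shape pairs the argument shapes with the result shape:
    program_info = ((array_info,), array_info)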
@@ -64,7 +64,13 @@ bool NumpyTypeIsValid(int np_type);
 // providing the array dimensions.
 //
 // The return value is a new reference.
-PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
+Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape);
+
+// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple
+// of argument shapes and result_shape is the result shape. Each shape is as
+// described in PyShapeInfoFromXlaShape's comment.
+Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
+    const ProgramShape& shape);
 
 // Converts a Python object with a method interface matching that of
 // xla_client.Shape into an XLA Shape object.
@@ -36,7 +36,7 @@ from tensorflow.compiler.xla.service import hlo_pb2
 
 
 # Most functions are snake_case for consistency with other modules, whereas
-# method names of ComputationBuilder and LocalComputation are CamelCase for
+# method names of ComputationBuilder and Computation are CamelCase for
 # consistency with XLA.
 # pylint: disable=invalid-name
 
@@ -50,7 +50,7 @@ from tensorflow.compiler.xla.service import hlo_pb2
 # which case we need to be able to detect when incompatible versions are
 # installed.
 def version():
-  return (0, 1, 7)
+  return (0, 1, 8)
 
 
 _OP_METADATA_FIELDS = [
@@ -66,6 +66,10 @@ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
 class Backend(object):
   """Abstract base class for XLA backends."""
 
+  @abc.abstractmethod
+  def device_count(self):
+    """Returns the number of devices known to the backend."""
+
   @abc.abstractmethod
   def buffer_from_pyval(self, pyval, device=0):
     """Allocates a fresh buffer and populates it with `pyval`."""
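With device_count added, Backend's abstract surface covers device discovery alongside the buffer, compile, and execute hooks. A toy delegating wrapper, purely to illustrate that surface (LoggingBackend is a hypothetical name, not part of the patch):

    class LoggingBackend(object):
        """Sketch: wraps any concrete backend and traces two of its hooks."""

        def __init__(self, wrapped):
            self._wrapped = wrapped  # a real Backend, e.g. XlaLocalBackend

        def device_count(self):
            print('device_count()')
            return self._wrapped.device_count()

        def buffer_from_pyval(self, pyval, device=0):
            print('buffer_from_pyval(device=%d)' % device)
            return self._wrapped.buffer_from_pyval(pyval, device)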
@@ -95,25 +99,39 @@ class Backend(object):
     """Runs an executable in a replicated manner."""
 
 
+def _maybe_encode_string(s):
+  if six.PY3:
+    return s.encode('utf-8')
+  else:
+    return s
+
+
 class XlaLocalBackend(Backend):
   """XLA backend implemented using the in-process xla::LocalClient API."""
 
+  def __init__(self, platform=None):
+    platform = platform or _get_default_platform_name()
+    self.client = c_api.LocalClient.Get(_maybe_encode_string(platform))
+
+  def device_count(self):
+    return self.client.DeviceCount()
+
   def buffer_from_pyval(self, pyval, device=0):
-    return c_api.LocalShapedBuffer.FromLiteral(pyval, None, device)
+    return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device)
 
   def delete_buffer(self, c_buffer):
     c_api.DeleteLocalShapedBuffer(c_buffer)
 
   def destructure_tuple(self, c_buffer):
-    result = c_api.DestructureLocalShapedBufferTuple(c_buffer)
+    result = c_buffer.DestructureTuple()
     return [result.Release(i) for i in xrange(result.size())]
 
   def compile(self, c_computation, argument_shapes, compile_options):
-    return c_computation.Compile(argument_shapes, compile_options)
+    return c_computation.Compile(argument_shapes, compile_options, self.client)
 
   def delete_executable(self, executable):
-    assert isinstance(executable, c_api.CompiledLocalComputation)
-    c_api.DeleteCompiledLocalComputation(executable)
+    assert isinstance(executable, c_api.LocalExecutable)
+    c_api.DeleteLocalExecutable(executable)
 
   def execute(self, executable, args):
     return executable.Execute(args)
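XlaLocalBackend is no longer a stateless shim: it owns the LocalClient for one platform, and every buffer, compile, and executable call is routed through that client. A sketch of how it would be used directly, assuming a TensorFlow build with XLA enabled:

    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.XlaLocalBackend(platform='Host')  # in-process CPU client
    print(backend.device_count())  # devices visible to this client's platform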
@@ -130,6 +148,9 @@ class XrtBackend(Backend):
   def __init__(self, target):
     self.target = target
 
+  def device_count(self):
+    return 1  # Multidevice execution not implemented.
+
   def buffer_from_pyval(self, pyval, device=0):
     if device != 0:
       raise NotImplementedError(
@@ -150,8 +171,8 @@ class XrtBackend(Backend):
         _maybe_encode_string(self.target))
 
   def delete_executable(self, executable):
-    assert isinstance(executable, c_api.CompiledXrtComputation)
-    c_api.DeleteCompiledXrtComputation(executable)
+    assert isinstance(executable, c_api.XrtExecutable)
+    c_api.DeleteXrtExecutable(executable)
 
   def execute(self, executable, args):
     return executable.Execute(args)
@@ -163,7 +184,20 @@ class XrtBackend(Backend):
     return [executable.Execute(per_replica_args[0])]
 
 
-XLA_LOCAL_BACKEND = XlaLocalBackend()
+_default_platform_name = 'Host'
+_default_backend = None
+
+
+def _get_default_platform_name():
+  return _default_platform_name
+
+
+def _get_default_local_backend():
+  global _default_backend
+  global _default_platform_name
+  if _default_backend is None:
+    _default_backend = XlaLocalBackend(_default_platform_name)
+  return _default_backend
 
 
 class BackendType(enum.Enum):
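Replacing the eager XLA_LOCAL_BACKEND module singleton with a lazily built default means importing the module no longer spins up a client, and the platform stays configurable until first use. A sketch (platform names other than 'Host' depend on the local build; _get_default_local_backend is module-private, shown here only to illustrate the caching):

    from tensorflow.compiler.xla.python import xla_client

    xla_client.initialize_platform_name('Host')        # choose before first use
    backend = xla_client._get_default_local_backend()  # constructed here, once
    assert backend is xla_client._get_default_local_backend()  # then cached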
@@ -174,7 +208,7 @@ class BackendType(enum.Enum):
 def BackendSpec(backend, target):
   """Compatibility wrapper to support older clients. Do not use in new code."""
   if backend == BackendType.XLA_LOCAL:
-    return XLA_LOCAL_BACKEND
+    return _get_default_local_backend()
   elif backend == BackendType.XRT:
     return XrtBackend(target)
   else:
@@ -201,13 +235,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
       source_line=lineno)
 
 
-def _maybe_encode_string(s):
-  if six.PY3:
-    return s.encode('utf-8')
-  else:
-    return s
-
-
 class PaddingType(enum.Enum):
   VALID = 1
   SAME = 2
@@ -346,22 +373,18 @@ class LocalBuffer(object):
   means the referent is in device memory.
   """
 
-  def __init__(self, c_buffer, backend, replica):
+  def __init__(self, c_buffer, backend, device):
     self.c_buffer = c_buffer
     self._backend = backend
-    self._replica = replica
+    self._device = device
 
   @staticmethod
-  def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND):
+  def from_pyval(pyval, device=0, backend=None):
     """Allocate and copy to XLA the given python value."""
+    backend = backend or _get_default_local_backend()
     pyval = require_numpy_array_layout(pyval)
-    num_replicas = get_replica_count()
-    if not 0 <= replica < num_replicas:
-      raise ValueError(
-          'Attempt to place buffer on replica {} when the replica count is {}'
-          .format(replica, num_replicas))
-    cbuf = backend.buffer_from_pyval(pyval, replica)
-    return LocalBuffer(cbuf, backend, replica)
+    cbuf = backend.buffer_from_pyval(pyval, device)
+    return LocalBuffer(cbuf, backend, device)
 
   def to_py(self):
     return self.c_buffer.ToLiteral()
|
|||||||
def shape(self):
|
def shape(self):
|
||||||
return _wrap_shape(self.c_buffer.shape())
|
return _wrap_shape(self.c_buffer.shape())
|
||||||
|
|
||||||
def replica(self):
|
def device(self):
|
||||||
return self._replica
|
return self._device
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
if self.c_buffer is not None:
|
if self.c_buffer is not None:
|
||||||
@ -383,7 +406,7 @@ class LocalBuffer(object):
|
|||||||
result = self._backend.destructure_tuple(self.c_buffer)
|
result = self._backend.destructure_tuple(self.c_buffer)
|
||||||
self.delete()
|
self.delete()
|
||||||
return tuple(
|
return tuple(
|
||||||
LocalBuffer(sub_buffer, replica=self._replica, backend=self._backend)
|
LocalBuffer(sub_buffer, device=self._device, backend=self._backend)
|
||||||
for sub_buffer in result)
|
for sub_buffer in result)
|
||||||
|
|
||||||
def is_deleted(self):
|
def is_deleted(self):
|
||||||
@@ -533,6 +556,16 @@ class Shape(object):
     updated._check_minor_to_major()  # pylint: disable=protected-access
     return updated
 
+  def with_major_to_minor_layout_if_absent(self):
+    """Returns a copy of a shape with missing layouts set to major-to-minor."""
+
+    def f(a):
+      if a.minor_to_major():
+        return None
+      return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1)))
+
+    return self.map_leaves(f)
+
   def serialize(self, proto):
     """Serializes 'shape' into proto."""
     if self.is_tuple():
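The per-leaf rule in with_major_to_minor_layout_if_absent only touches leaves whose minor_to_major is empty, and the default it fills in is the row-major permutation (rank-1, ..., 1, 0). The arithmetic in isolation:

    # For a rank-3 array leaf with no explicit layout:
    rank = 3
    default_minor_to_major = tuple(range(rank - 1, -1, -1))
    assert default_minor_to_major == (2, 1, 0)  # major-to-minor, i.e. row-major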
@@ -548,6 +581,10 @@ class Shape(object):
       proto.layout.minor_to_major.extend(self.minor_to_major())
 
 
+ProgramShape = collections.namedtuple('ProgramShape',
+                                      ('parameter_shapes', 'result_shape'))
+
+
 def _wrap_shape(shape_info):
   dtype, dims = shape_info
   element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
@@ -581,7 +618,7 @@ class CompileOptions(object):
     self.num_replicas = get_replica_count()
 
 
-def transfer_to_infeed(value, replica_number=None):
+def transfer_to_infeed(value, device_ordinal=0):
   """Transfers the given value into the XLA infeed queue.
 
   XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
@@ -591,52 +628,49 @@ def transfer_to_infeed(value, replica_number=None):
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
-    replica_number: the replica number to infeed the value to -- if not
-      provided, then the default replica (trivially replica 0) is used.
+    device_ordinal: the device to infeed the value to. Each device has a
+      distinct infeed queue.
   """
-  if replica_number is None:
-    c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
-  else:
-    c_api.TransferToInfeedLocalReplica(
-        require_numpy_array_layout(value), replica_number)
+  # TODO(phawkins): support non-default backends.
+  backend = _get_default_local_backend()
+  backend.client.TransferToInfeed(
+      require_numpy_array_layout(value), device_ordinal)
 
 
-def transfer_from_outfeed(shape, replica_number=None):
-  """Transfers a literal of the given shape from replica_number's outfeed.
+def transfer_from_outfeed(shape, device_ordinal=0):
+  """Transfers a literal of the given shape from `device_ordinal`'s outfeed.
 
   Args:
     shape: The shape of the value to transfer from outfeed.
-    replica_number: The replica number ordinal to transfer the outfeed value
-      from. (Each replica has a distinct outfeed queue.)
+    device_ordinal: The device ordinal to transfer the outfeed value from. Each
+      device has a distinct outfeed queue.
 
   Returns:
     The literal value that is produced from the outfeed queue.
   """
-  return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)
+  # TODO(phawkins): support non-default backends.
+  backend = _get_default_local_backend()
+  return backend.client.TransferFromOutfeed(shape, device_ordinal)
 
 
-class LocalComputation(object):
-  """Python wrapper for a local XLA Computation.
+class Computation(object):
+  """Python wrapper for an XLA Computation.
 
-  A LocalComputation can be executed if it is compiled. Otherwise, it
-  can still be used as a Computation where required by the
-  ComputationBuilder methods.
+  A Computation can be compiled to form an Executable, or used as a
+  subcomputation in ComputationBuilder methods.
   """
 
-  def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND):
+  def __init__(self, c_computation, backend=None):
     self._c_computation = c_computation
+    # The backend argument is deprecated. Pass a backend to Compile() instead.
     self._backend = backend
-    self._is_compiled = is_compiled
 
   @property
   def computation(self):
-    if self._is_compiled:
-      raise ValueError(
-          'Attempt to read the XLA computation of a compiled LocalComputation.')
     return self._c_computation
 
   def GetProto(self):
-    """Get the HloModuleProto proto object in this local computation.
+    """Get the HloModuleProto proto object in this computation.
 
     Returns:
       An HloModuleProto proto object that has the whole-graph information.
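Infeed and outfeed addressing switches from replica numbers to device ordinals, and both transfers now go through the default backend's client rather than module-level c_api calls. A sketch of enqueueing to device 0 (the outfeed read is left commented out because it blocks until a running computation feeds the queue, and `shape` here stands for the expected literal's XLA shape):

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    value = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    xla_client.transfer_to_infeed(value)  # device_ordinal defaults to 0
    # result = xla_client.transfer_from_outfeed(shape, device_ordinal=0)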
@@ -645,30 +679,25 @@ class LocalComputation(object):
     proto = hlo_pb2.HloModuleProto.FromString(serialized)
     return proto
 
-  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
-    """Compiles an un-compiled local computation.
-
-    Local computations are the result of a "LocalComputationBuild'ing" process
-    -- they start in uncompiled form, and via a call to Compile() turn into a
-    compiled local computation.
-
-    Raises:
-      ValueError: if this is already a compiled local computation.
+  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None,
+              backend=None):
+    """Compiles a computation.
+
+    Computations are the result of a "ComputationBuild'ing" process.
 
     Arguments:
       argument_shapes: parameter shapes -- they are first laid out by layout_fn
         if layout_fn is provided. Otherwise, the default layout for those shapes
         will be used.
-      compile_options: options to use for compilation, includes an optional
-        laid out result shape for the computation.
+      compile_options: options to use for compilation, includes an optional laid
+        out result shape for the computation.
       layout_fn: lambda that is used to lay out the argument/result shapes.
+      backend: a `Backend` for which an executable should be generated.
 
     Returns:
-      A newly *compiled* local computation instance.
+      An Executable instance.
     """
-    if self._is_compiled:
-      raise ValueError('Attempt to compile a compiled local XLA computation.')
+    backend = backend or self._backend or _get_default_local_backend()
 
     result_shape = _wrap_shape(self.computation.GetReturnValueShape())
 
     if layout_fn:
@@ -681,29 +710,55 @@ class LocalComputation(object):
 
     compile_options = compile_options or CompileOptions()
     compile_options.result_shape = result_shape
-    c = self._backend.compile(self.computation, argument_shapes,
-                              compile_options)
-    return LocalComputation(c, is_compiled=True, backend=self._backend)
+    c = backend.compile(self.computation, argument_shapes, compile_options)
+    return Executable(c, backend=backend)
 
   def CompileWithExampleArguments(self,
                                   arguments=(),
                                   compile_options=None,
-                                  layout_fn=None):
+                                  layout_fn=None,
+                                  backend=None):
     return self.Compile(
         argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
         compile_options=compile_options,
-        layout_fn=layout_fn)
+        layout_fn=layout_fn,
+        backend=backend)
+
+  def GetProgramShape(self):
+    (arg_shapes, result_shape) = self._c_computation.GetProgramShape()
+    return ProgramShape([_wrap_shape(arg) for arg in arg_shapes],
+                        _wrap_shape(result_shape))
 
   def GetReturnValueShape(self):
     return _wrap_shape(self._c_computation.GetReturnValueShape())
 
+  def __del__(self):
+    # Python may have freed c_api first.
+    if c_api and self._c_computation:
+      assert isinstance(self._c_computation, c_api.Computation)
+      c_api.DeleteComputation(self._c_computation)
+
+
+class Executable(object):
+  """Python wrapper for an XLA Executable."""
+
+  def __init__(self, c_executable, backend=None):
+    self._c_executable = c_executable
+    self._device_ordinals = c_executable.DeviceOrdinals()
+    self._backend = backend
+
+  def DeviceOrdinals(self):
+    """Returns a list containing the device ordinals for each replica."""
+    return self._device_ordinals
+
   def Execute(self, arguments=(), check_for_deleted_args=True):
     """Execute on one replica with LocalBuffer arguments and return value."""
     if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
       raise ValueError('Executing with deleted local buffer argument')
     raw_args = [arg.c_buffer for arg in arguments]
-    output_buffer = self._backend.execute(self._c_computation, raw_args)
-    return LocalBuffer(output_buffer, backend=self._backend, replica=0)
+    output_buffer = self._backend.execute(self._c_executable, raw_args)
+    return LocalBuffer(
+        output_buffer, backend=self._backend, device=self._device_ordinals[0])
 
   def ExecutePerReplica(self, arguments=None):
     """Execute on many replicas with LocalBuffer arguments and return value.
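Splitting Computation from Executable makes the lifecycle explicit: build, then compile (choosing a backend), then execute on the devices the executable was placed on. Sketched end to end (ConstantF32Scalar is one of the builder methods forwarded from the C extension, as used in the tests further below):

    from tensorflow.compiler.xla.python import xla_client

    b = xla_client.ComputationBuilder('double')
    x = b.ConstantF32Scalar(21.0)
    b.Add(x, x)
    computation = b.Build()              # a Computation
    executable = computation.Compile()   # an Executable on the default backend
    print(executable.DeviceOrdinals())   # e.g. [0]
    print(executable.ExecuteWithPythonValues())  # expected: 42.0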
@@ -713,14 +768,12 @@ class LocalComputation(object):
       sequence comprises the arguments for execution on the i'th replica.
 
     Returns:
-      A list of the computation's outputs on each replica, as a LocalBuffer. If
+      A list of the computation's outputs for each replica, as a LocalBuffer. If
       a shallow sequence of arguments was passed in for `arguments`, then the
       sole, zero'th replica's output is returned instead, as a LocalBuffer.
     """
-    if not self._is_compiled:
-      raise ValueError('Cannot execute an uncompiled local XLA computation.')
     if arguments is None:
-      arguments = ((),) * get_replica_count()
+      arguments = ((),) * len(self._device_ordinals)
     else:
       arguments = [list(replica_args) for replica_args in arguments]
 
@@ -729,30 +782,35 @@ class LocalComputation(object):
       for arg in replica_args:
         if arg.is_deleted():
           raise ValueError('Executing with deleted local buffer argument')
-        if arg.replica() != replica:
+        if arg.device() != self._device_ordinals[replica]:
           raise ValueError(
-              'Executing on replica {} with argument from replica {}'.format(
-                  replica, arg.replica()))
+              'Executing on device {} with argument from device {}'.format(
+                  self._device_ordinals[replica], arg.device()))
 
     # Pull out argument buffer handles
+    # pylint: disable=g-complex-comprehension
     stripped_args = [
         [arg.c_buffer for arg in replica_args] for replica_args in arguments
     ]
 
     # Execute
-    output_buffers = self._backend.execute_replicated(
-        self._c_computation, stripped_args)
+    output_buffers = self._backend.execute_replicated(self._c_executable,
+                                                      stripped_args)
 
     # Wrap output handles in LocalBuffer instances
     return tuple(
-        LocalBuffer(output_buffer, backend=self._backend, replica=replica)
+        LocalBuffer(
+            output_buffer,
+            backend=self._backend,
+            device=self._device_ordinals[replica])
         for replica, output_buffer in enumerate(output_buffers))
 
   def ExecuteWithPythonValues(self, arguments=()):
     """Execute on one replica with Python values as arguments and output."""
 
     def put(arg):
-      return LocalBuffer.from_pyval(arg, backend=self._backend)
+      return LocalBuffer.from_pyval(
+          arg, device=self._device_ordinals[0], backend=self._backend)
 
     arguments = [put(arg) for arg in arguments]
     return self.Execute(arguments).to_py()
@@ -760,22 +818,19 @@ class LocalComputation(object):
   def ExecuteWithPythonValuesPerReplica(self, arguments):
     """Execute on many replicas with Python values as arguments and output."""
 
-    def put(arg, replica):
-      return LocalBuffer.from_pyval(arg, replica, backend=self._backend)
+    def put(arg, device):
+      return LocalBuffer.from_pyval(arg, device, backend=self._backend)
 
-    arguments = [[put(arg, replica)
-                  for arg in replica_args]
-                 for replica, replica_args in enumerate(arguments)]
+    # pylint: disable=g-complex-comprehension
+    arguments = [[
+        put(arg, self._device_ordinals[replica]) for arg in replica_args
+    ] for replica, replica_args in enumerate(arguments)]
     return [out.to_py() for out in self.ExecutePerReplica(arguments)]
 
   def __del__(self):
     # Python may have freed c_api first.
-    if c_api and self._c_computation:
-      if self._is_compiled:
-        self._backend.delete_executable(self._c_computation)
-      else:
-        assert isinstance(self._c_computation, c_api.LocalComputation)
-        c_api.DeleteLocalComputation(self._c_computation)
+    if c_api and self._c_executable:
+      self._backend.delete_executable(self._c_executable)
 
 
 def _make_replica_group_proto(replica_group):
@@ -788,8 +843,8 @@ class ComputationBuilder(object):
   """XLA computation builder.
 
   Enqueues XLA ops in sequence and in order to build a
-  LocalComputation, which in turn can be compiled into a
-  CompiledLocalComputation, which in turn can be locally executed.
+  Computation, which in turn can be compiled into a
+  LocalExecutable, which in turn can be locally executed.
   """
 
   # The methods of this class map 1-to-1 onto the XLA C++
@@ -800,16 +855,23 @@ class ComputationBuilder(object):
   # pylint: disable=g-doc-args
 
   def __init__(self, name):
-    self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
+    self._client = c_api.ComputationBuilder(name.encode('utf8'))
     self._parameter_numbering = itertools.count()
 
-  def Build(self, root=None, backend=XLA_LOCAL_BACKEND):
+  def Build(self, root=None, backend=None):
+    """Builds a `Computation` from the contents of the builder.
+
+    Args:
+      root: if not None, the operator containing the return value of the
+        computation.
+      backend: deprecated. Pass a `backend` to `Computation.Compile` instead.
+
+    Returns:
+      A `Computation`.
+    """
     if root is not None:
-      return LocalComputation(
-          self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
+      return Computation(self._client.BuildWithRoot(root), backend=backend)
     else:
-      return LocalComputation(
-          self._client.Build(), is_compiled=False, backend=backend)
+      return Computation(self._client.Build(), backend=backend)
 
   def SetOpMetadata(self, op_metadata):
     """Set metadata for operations that are about to be enqueued."""
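Build() roots the computation at the last enqueued op unless a root is passed explicitly, which is how a prefix of the enqueued graph can be returned. For instance (a sketch reusing the builder methods above):

    b = xla_client.ComputationBuilder('root-demo')
    x = b.ConstantF32Scalar(2.0)
    y = b.Add(x, x)          # not the default root...
    b.Mul(y, y)              # ...this last op would be
    comp = b.Build(root=y)   # root at `y` instead, ignoring the Mul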
@@ -1461,7 +1523,7 @@ class ComputationBuilder(object):
 
     Args:
       operand: a LocalOp to test.
-    Returns: a LocalComputation that is rooted on the given `operand` which is a
+    Returns: a Computation that is rooted on the given `operand` which is a
       compile-time constant.
     """
     return self._client.BuildConstantSubGraph(operand)
@@ -1662,7 +1724,7 @@ def _forward_methods_to_local_builder():
 
   Set up methods, corresponding to unary and binary XLA operations,
   whose calls are forwarded in a boilerplate manner to the underlying
-  LocalComputationBuilder C-extension API.
+  ComputationBuilder C-extension API.
   """
 
   def forward_to_local_builder_with_handles(target_method, is_binop=False):
@@ -1682,13 +1744,13 @@ def _forward_methods_to_local_builder():
 
   for method_name in _UNARY_OPS:
     forward = forward_to_local_builder_with_handles(
-        getattr(c_api.LocalComputationBuilder, method_name))
+        getattr(c_api.ComputationBuilder, method_name))
     forward.__name__ = method_name
     setattr(ComputationBuilder, method_name, forward)
 
   for method_name in _BINARY_OPS:
     forward = forward_to_local_builder_with_handles(
-        getattr(c_api.LocalComputationBuilder, method_name), is_binop=True)
+        getattr(c_api.ComputationBuilder, method_name), is_binop=True)
     forward.__name__ = method_name
     setattr(ComputationBuilder, method_name, forward)
 
@@ -1696,8 +1758,14 @@ def _forward_methods_to_local_builder():
 _forward_methods_to_local_builder()
 
 
+_default_replica_count = 1
+
+
 def initialize_replica_count(replica_count):
-  """Initializes the desired replica count to use on XLA service init.
+  """Initializes the default replica count to use.
+
+  Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
+  instead.
 
   Args:
     replica_count: number of replicas that are desired for set up during XLA
@@ -1706,31 +1774,30 @@ def initialize_replica_count(replica_count):
   Raises:
     A runtime exception if the XLA service has already been initialized.
   """
-  c_api.InitializeReplicaCount(replica_count)
-
-
-def initialize_platform_name(platform_name):
-  """Initializes the desired platform name to use on XLA service init.
-
-  Args:
-    platform_name: string name of platform.
-
-  Raises:
-    A runtime exception if the XLA service has already been initialized.
-    A runtime exception if the platform does not exist, or there are no devices
-    with that platform.
-  """
-  platform_name = _maybe_encode_string(platform_name)
-  c_api.InitializePlatformName(platform_name)
+  global _default_replica_count
+  _default_replica_count = replica_count
 
 
 def get_replica_count():
-  """Returns the current replica count used for the XLA service.
+  """Returns the default replica count.
 
-  Note: this will return a value whether the XLA service has been initialized
-  yet or not.
+  Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
+  instead.
   """
-  return c_api.GetReplicaCount()
+  return _default_replica_count
+
+
+def initialize_platform_name(platform_name):
+  """Initializes the default platform name to use for XLA.
+
+  Args:
+    platform_name: string name of platform.
+  """
+  global _default_platform_name
+  _default_platform_name = platform_name
+
+  # Make sure the platform is valid by trying to instantiate it.
+  _get_default_local_backend()
 
 
 def register_cpu_custom_call_target(name, fn):
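With the service-level setters gone, these functions only record module defaults: the replica count feeds CompileOptions.num_replicas, and the platform name is validated eagerly by instantiating the default backend. For example:

    from tensorflow.compiler.xla.python import xla_client

    xla_client.initialize_replica_count(2)       # default for CompileOptions
    assert xla_client.get_replica_count() == 2
    xla_client.initialize_platform_name('Host')  # raises if 'Host' is unavailable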
@@ -29,7 +29,7 @@ from tensorflow.compiler.xla.python import xla_client
 import unittest
 
 
-class LocalComputationTest(unittest.TestCase):
+class ComputationTest(unittest.TestCase):
   """Base class for running an XLA Computation through the local client."""
 
   def _NewComputation(self, name=None):
@@ -85,7 +85,7 @@ def NumpyArrayBool(*args, **kwargs):
   return np.array(*args, dtype=np.bool, **kwargs)
 
 
-class ComputationsWithConstantsTest(LocalComputationTest):
+class ComputationsWithConstantsTest(ComputationTest):
   """Tests focusing on Constant ops."""
 
   def testConstantScalarSumS8(self):
@@ -304,7 +304,7 @@ class ComputationsWithConstantsTest(LocalComputationTest):
     self._ExecuteAndCompareClose(c, expected=0.75)
 
 
-class ParametersTest(LocalComputationTest):
+class ParametersTest(ComputationTest):
   """Tests focusing on Parameter ops and argument-passing."""
 
   def setUp(self):
@@ -384,7 +384,7 @@ class ParametersTest(LocalComputationTest):
         expected=[-4.3, 1.3, -6.3, 3.3])
 
 
-class LocalBufferTest(LocalComputationTest):
+class LocalBufferTest(ComputationTest):
   """Tests focusing on execution with LocalBuffers."""
 
   def _Execute(self, c, arguments):
@@ -482,7 +482,7 @@ class LocalBufferTest(LocalComputationTest):
     self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
 
 
-class SingleOpTest(LocalComputationTest):
+class SingleOpTest(ComputationTest):
   """Tests for single ops.
 
   The goal here is smoke testing - to exercise the most basic functionality of
@@ -1175,7 +1175,7 @@ class SingleOpTest(LocalComputationTest):
     np.testing.assert_allclose(g, expected, rtol=1e-4)
 
 
-class EmbeddedComputationsTest(LocalComputationTest):
+class EmbeddedComputationsTest(ComputationTest):
   """Tests for XLA graphs with embedded computations (such as maps)."""
 
   def _CreateConstantS32Computation(self):
@@ -1639,7 +1639,7 @@ class EmbeddedComputationsTest(LocalComputationTest):
     self._ExecuteAndCompareClose(c, expected=expected)
 
 
-class ErrorTest(LocalComputationTest):
+class ErrorTest(ComputationTest):
 
   def setUp(self):
     self.f32_scalar_2 = NumpyArrayF32(2.0)
@@ -1656,7 +1656,7 @@ class ErrorTest(LocalComputationTest):
         lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
 
 
-class ComputationRootTest(LocalComputationTest):
+class ComputationRootTest(ComputationTest):
   """Tests related to setting the root of the computation."""
 
   def testComputationRootDifferentFromLastOp(self):
@@ -3529,6 +3529,37 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "stable_sort_expander",
+    srcs = ["stable_sort_expander.cc"],
+    hdrs = ["stable_sort_expander.h"],
+    deps = [
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_pass",
+        ":op_expander_pass",
+        "//tensorflow/compiler/xla:statusor",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+    ],
+)
+
+tf_cc_test(
+    name = "stable_sort_expander_test",
+    srcs = ["stable_sort_expander_test.cc"],
+    deps = [
+        ":algebraic_simplifier",
+        ":hlo_matchers",
+        ":hlo_parser",
+        ":pattern_matcher",
+        ":pattern_matcher_gmock",
+        ":stable_sort_expander",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "tuple_util",
     srcs = ["tuple_util.cc"],
@@ -280,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
         hlo));
   }
 
-  // Helper method to perform and add reduction in a single dimension.
-  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
+  // Converts to primitive type if the input hlo is not that type, otherwise
+  // returns the original hlo.
+  HloInstruction* AsType(HloInstruction* hlo,
+                         const PrimitiveType element_type) {
+    if (hlo->shape().element_type() == element_type) {
+      return hlo;
+    }
+    return computation_->AddInstruction(HloInstruction::CreateConvert(
+        ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
+  }
+
+  // Transposes a dot operand such that the batch dimensions are the most
+  // major, and the contracting dimensions are most minor.
+  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
+      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
+      absl::Span<const int64> contracting_dimensions) {
+    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
+                                            batch_dimensions.end());
+    for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
+      if (!(absl::c_linear_search(batch_dimensions, i) ||
+            absl::c_linear_search(contracting_dimensions, i))) {
+        transpose_dimensions.push_back(i);
+      }
+    }
+    transpose_dimensions.insert(transpose_dimensions.end(),
+                                contracting_dimensions.begin(),
+                                contracting_dimensions.end());
+    return MakeTransposeHlo(dot_operand, transpose_dimensions);
+  }
+
+  // Helper method to perform and add reduction on a list of dimensions.
+  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
     HloInstruction* zero =
         computation_->AddInstruction(HloInstruction::CreateConstant(
             LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
     HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
-    Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
+    Shape shape = ShapeUtil::FilterDimensions(
+        [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
+        hlo->shape());
     return computation_->AddInstruction(HloInstruction::CreateReduce(
-        shape, hlo, zero, {dim}, AddReduce_computation));
+        shape, hlo, zero, dims, AddReduce_computation));
+  }
+
+  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
+    return AddReduce(hlo, std::vector<int64>{dim});
   }
 
   // Convenience method for replacing an instruction with a bitcast. If operand
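NormalizeDotOperandToBatchMajorAndContractingMinor is a plain dimension permutation: batch dimensions first, free dimensions in the middle, contracting dimensions last. The same permutation expressed over a NumPy array (a sketch of the intent, not the C++ pass itself):

    import numpy as np

    def normalize(operand, batch_dims, contracting_dims):
        # Batch dims lead, contracting dims trail, everything else in between.
        outer = [d for d in range(operand.ndim)
                 if d not in batch_dims and d not in contracting_dims]
        return np.transpose(operand,
                            list(batch_dims) + outer + list(contracting_dims))

    x = np.zeros((2, 3, 4, 5))
    assert normalize(x, batch_dims=[2], contracting_dims=[1]).shape == (4, 2, 5, 3)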
@@ -1120,16 +1156,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
     std::swap(rhs_collapsing_dim, rhs_kept_dim);
   }
 
-  auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) {
-    if (hlo->shape().element_type() == element_type) {
-      return hlo;
-    }
-    return computation_->AddInstruction(HloInstruction::CreateConvert(
-        ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
-  };
-
   auto reshape_if_necessary = [&](HloInstruction* hlo) {
-    hlo = as_type(hlo, dot->shape().element_type());
+    hlo = AsType(hlo, dot->shape().element_type());
     if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
       hlo = computation_->AddInstruction(
           HloInstruction::CreateReshape(dot->shape(), hlo));
@@ -1138,7 +1166,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
   };
 
   auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
-    return AddReduce(as_type(hlo, F32), dim);
+    return AddReduce(AsType(hlo, F32), dim);
   };
 
   auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
@@ -1247,8 +1275,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
     return dims;
   };
 
-  // If the contracting dimension is 1, remove the degnerate dimnesions from the
-  // lhs and rhs, broadcast each to the result shape and multiply.
+  // If the contracting dimension is 1, remove the degenerate dimensions from
+  // the lhs and rhs, broadcast each to the result shape and multiply.
   if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
       (rhs_kept_dim == rhs_rank - 1 ||
        (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
@@ -1608,34 +1636,26 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
   // If there are no contracting dimensions, a dot can be rewritten as
   // mul(broadcast(transpose(x)),broadcast(transpose(y)))
   if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
-    std::vector<int64> lhs_transpose(
-        dot->dot_dimension_numbers().lhs_batch_dimensions().begin(),
-        dot->dot_dimension_numbers().lhs_batch_dimensions().end());
-    for (int64 i = 0; i < lhs->shape().rank(); ++i) {
-      if (!absl::c_linear_search(
-              dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) {
-        lhs_transpose.push_back(i);
-      }
-    }
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
-                        MakeTransposeHlo(lhs, lhs_transpose));
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_lhs,
+        NormalizeDotOperandToBatchMajorAndContractingMinor(
+            lhs,
+            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
+            AsInt64Slice(
+                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
     if (dot->shape().rank() != lhs->shape().rank()) {
       std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
       absl::c_iota(lhs_broadcast_dims, 0);
       new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_lhs, lhs_broadcast_dims));
     }
-    std::vector<int64> rhs_transpose(
-        dot->dot_dimension_numbers().rhs_batch_dimensions().begin(),
-        dot->dot_dimension_numbers().rhs_batch_dimensions().end());
-    for (int64 i = 0; i < rhs->shape().rank(); ++i) {
-      if (!absl::c_linear_search(
-              dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) {
-        rhs_transpose.push_back(i);
-      }
-    }
-    TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
-                        MakeTransposeHlo(rhs, rhs_transpose));
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_rhs,
+        NormalizeDotOperandToBatchMajorAndContractingMinor(
+            rhs,
+            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
+            AsInt64Slice(
+                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
     if (dot->shape().rank() != rhs->shape().rank()) {
       std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
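The rewrite in this hunk handles dots with no contracting dimensions as a broadcasted multiply; the hunk below extends the idea by reducing over the contracting dimensions afterwards. Both identities are easy to check in NumPy:

    import numpy as np

    # No contracting dims: the dot degenerates to an outer (broadcast) product.
    x, y = np.arange(3.0), np.arange(4.0)
    assert np.allclose(x[:, None] * y[None, :], np.outer(x, y))

    # With a contracting dim: multiply via broadcast, then reduce over it.
    a, b = np.random.rand(2, 3), np.random.rand(3)
    assert np.allclose((a * b[None, :]).sum(axis=1), np.dot(a, b))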
@@ -1651,6 +1671,78 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
         new_lhs, new_rhs));
   }
 
+  // If the lhs or rhs have only batch and contracting dimensions, a dot can be
+  // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
+  if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
+           dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
+       lhs->shape().rank()) ||
+      (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
+           dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
+       rhs->shape().rank())) {
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_lhs,
+        NormalizeDotOperandToBatchMajorAndContractingMinor(
+            lhs,
+            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
+            AsInt64Slice(
+                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_rhs,
+        NormalizeDotOperandToBatchMajorAndContractingMinor(
+            rhs,
+            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
+            AsInt64Slice(
+                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
+
+    int64 lhs_outer_dims =
+        lhs->shape().rank() -
+        (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
+         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
+    int64 rhs_outer_dims =
+        rhs->shape().rank() -
+        (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
+         dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
+    CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
+    if (rhs_outer_dims > 0) {
+      std::vector<int64> lhs_broadcast_dims(
+          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
+      absl::c_iota(lhs_broadcast_dims, 0);
+      lhs_broadcast_dims.resize(lhs->shape().rank());
+      std::iota(lhs_broadcast_dims.begin() +
+                    dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
+                lhs_broadcast_dims.end(),
+                dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
+                    rhs_outer_dims);
+      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+          new_rhs->shape(), new_lhs, lhs_broadcast_dims));
+    } else if (lhs_outer_dims > 0) {
+      std::vector<int64> rhs_broadcast_dims(
+          dot->dot_dimension_numbers().rhs_batch_dimensions_size());
+      absl::c_iota(rhs_broadcast_dims, 0);
+      rhs_broadcast_dims.resize(rhs->shape().rank());
+      std::iota(rhs_broadcast_dims.begin() +
+                    dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
+                rhs_broadcast_dims.end(),
+                dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
+                    lhs_outer_dims);
+      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+          new_lhs->shape(), new_rhs, rhs_broadcast_dims));
+    }
+
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
+                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
+    std::vector<int64> reduce_dims(
+        dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
+    new_dot = AsType(new_dot, F32);
+    const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
+    absl::c_iota(
+        reduce_dims,
+        outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
+    new_dot = AddReduce(new_dot, reduce_dims);
+    new_dot = AsType(new_dot, dot->shape().element_type());
+    return ReplaceInstruction(dot, new_dot);
+  }
+
   if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
       dot->shape().rank() > 2) {
     if (options_.enable_dot_strength_reduction() &&
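
The identity behind this rewrite is easiest to check in the rank-1 case, where
reduce(mul(broadcast(transpose(x)), broadcast(transpose(y)))) collapses to a plain sum of
elementwise products. A minimal standalone illustration in plain C++ (no XLA types):

    #include <cstddef>

    // dot(a, b) for rank-1 inputs: elementwise multiply, then reduce over the
    // single contracting dimension -- the scalar analogue of the HLO rewrite.
    double DotViaMulReduce(const double* a, const double* b, size_t n) {
      double acc = 0.0;
      for (size_t i = 0; i < n; ++i) {
        acc += a[i] * b[i];
      }
      return acc;
    }

Note that the pass upcasts to F32 (AsType(new_dot, F32)) before AddReduce and converts back
afterwards, presumably so BF16 inputs are not accumulated at BF16 precision.
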
@@ -2753,8 +2753,9 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
   Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
   auto keys = builder.AddInstruction(
       HloInstruction::CreateParameter(0, keys_shape, "keys"));
-  TF_ASSERT_OK(
-      MakeSortHlo(keys_shape, {keys}, 0, &builder, module.get()).status());
+  TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder,
+                           module.get())
+                   .status());
   HloComputation* computation = module->AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(default_options_);
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
@@ -2775,7 +2776,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
       HloInstruction::CreateParameter(2, values_shape, "values1"));
   TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape(
                                {keys_shape, values_shape, values_shape}),
-                           {keys, values0, values1}, 0, &builder, module.get())
+                           {keys, values0, values1}, 0, /*is_stable=*/false,
+                           &builder, module.get())
                    .status());
   HloComputation* computation = module->AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(default_options_);
@@ -3712,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
   HloInstruction* y =
       builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
   DotDimensionNumbers dot_dnums;
-  dot_dnums.add_lhs_contracting_dimensions(1);
-  dot_dnums.add_rhs_contracting_dimensions(0);
+  dot_dnums.add_lhs_batch_dimensions(0);
+  dot_dnums.add_rhs_batch_dimensions(0);
   builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
                                                    DefaultPrecisionConfig(2)));
   std::unique_ptr<HloComputation> dot_computation(builder.Build());
@@ -4220,12 +4222,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   int m, k, n;
   PrimitiveType element_type;
   std::tie(m, k, n, element_type) = GetParam();
-
-  Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n});
-  Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k})
-                          : ShapeUtil::MakeShape(element_type, {1, 3, 5, m});
-  Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n})
-                          : ShapeUtil::MakeShape(element_type, {1, 3, 5, n});
+  std::vector<int64> lhs_dims = {1, 3, 5};
+  std::vector<int64> rhs_dims = lhs_dims;
+  std::vector<int64> output_dims = lhs_dims;
+  if (m > 0) {
+    lhs_dims.push_back(m);
+    output_dims.push_back(m);
+  }
+  if (k > 0) {
+    lhs_dims.push_back(k);
+    rhs_dims.push_back(k);
+  }
+  if (n > 0) {
+    rhs_dims.push_back(n);
+    output_dims.push_back(n);
+  }
+  Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
+  Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
+  Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
   HloComputation::Builder builder(TestName());
 
   auto lhs = builder.AddInstruction(
@@ -4240,7 +4254,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   dot_dnums.add_rhs_batch_dimensions(1);
   dot_dnums.add_rhs_batch_dimensions(2);
   if (k > 0) {
-    dot_dnums.add_lhs_contracting_dimensions(4);
+    dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
     dot_dnums.add_rhs_contracting_dimensions(3);
   }
   builder.AddInstruction(HloInstruction::CreateDot(
@@ -4248,9 +4262,9 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   auto computation = module->AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(default_options_);
   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
-  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1;
-  const bool computation_should_be_modified = dot_should_be_transformed;
-  EXPECT_EQ(changed, computation_should_be_modified);
+  const bool dot_should_be_transformed =
+      m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
+  EXPECT_EQ(changed, dot_should_be_transformed);
   bool has_no_dot = true;
   for (const auto& hlo : computation->instructions()) {
     if (hlo->opcode() == HloOpcode::kDot) {
@@ -4261,10 +4275,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
 }
 
-INSTANTIATE_TEST_SUITE_P(
-    BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
-    ::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2),
-                       ::testing::Values(1, 2), ::testing::Values(F32, BF16)));
+INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation,
+                         BatchDotStrengthReductionTest,
+                         ::testing::Combine(::testing::Values(-1, 1, 2),
+                                            ::testing::Values(-1, 1, 2),
+                                            ::testing::Values(-1, 1, 2),
+                                            ::testing::Values(F32, BF16)));
 
 class DotStrengthReductionTest
     : public AlgebraicSimplifierTest,
@@ -32,15 +32,13 @@ limitations under the License.
 
 namespace xla {
 
-namespace {
-
 namespace m = match;
 
 // Checks if the argument instruction is an AllReduce, followed by a certain
 // sequence of instructions and then a CRS. It must be possible to move
 // the AR past each instruction in the sequence. Returns the CRS, which is the
 // last instruction in the sequence.
-absl::optional<HloInstruction*> MatchesArCrsPattern(
+absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
     HloInstruction* instruction) {
   auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
     if (instruction->user_count() != 1) {
@@ -77,23 +75,23 @@ absl::optional<HloInstruction*> MatchesArCrsPattern(
     return absl::nullopt;
   }
   auto next = instruction->users()[0];
+  int64 distance = 1;
   while (!next->IsCrossReplicaAllReduce()) {
     if (can_ar_move_past_instruction(next)) {
       next = next->users()[0];
     } else {
       return absl::nullopt;
     }
+    ++distance;
   }
   if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
       computation_is_addition(next->called_computations()[0])) {
-    return absl::optional<HloInstruction*>(next);
+    return absl::optional<ArCrsPair>(ArCrsPair(instruction, next, distance));
   } else {
     return absl::nullopt;
   }
 }
 
-}  // namespace
-
 absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
     HloInstruction* instruction) {
   CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
@@ -235,15 +233,55 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
 }
 
 void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
+  // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
+  // ... , (ARn, CRS).
+  // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
+  // and later find that AR1's distance from the CRS is longer, we discard
+  // AR2 and start tracking AR1. We put the discarded ids in this set, in order
+  // to skip processing of short paths when we encounter the other ARs that
+  // have the same id as AR2.
+  absl::flat_hash_set<int64> discarded_ar_ids;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     for (HloInstruction* instruction : computation->instructions()) {
-      auto maybe_crs = MatchesArCrsPattern(instruction);
-      if (maybe_crs) {
-        auto crs = *maybe_crs;
+      auto maybe_pair = MatchesArCrsPattern(instruction);
+      if (maybe_pair) {
+        auto pair = *maybe_pair;
         int64 ar_id = *(instruction->all_reduce_id());
-        if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) {
-          all_reduce_map_[ar_id].push_back(instruction);
-          crs_reserved_map_[crs] = ar_id;
+        if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
+          continue;
+        }
+        auto it = crs_reserved_map_.find(pair.crs);
+        if (it != crs_reserved_map_.end()) {
+          auto prev_ar_id = it->second;
+          // Since there is another AR paired with CRS,
+          // all_reduce_map_[prev_ar_id] should exist, but
+          // all_reduce_map_[ar_id] shouldn't.
+          CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
+          CHECK_NE(prev_ar_id, ar_id);
+          auto prev_pair = all_reduce_map_[prev_ar_id].back();
+          int64 prev_distance = prev_pair.distance;
+          if (prev_distance < pair.distance) {
+            // The current AR's distance to CRS is longer than the previously
+            // tracked AR, so we discard the previous AR.
+            all_reduce_map_.erase(prev_ar_id);
+            discarded_ar_ids.insert(prev_ar_id);
+            all_reduce_map_[ar_id].push_back(pair);
+            crs_reserved_map_[pair.crs] = ar_id;
+          } else {
+            // Discard the current AR id because we are keeping the previously
+            // tracked AR.
+            discarded_ar_ids.insert(ar_id);
+          }
+        } else {
+          if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
+            int64 prev_distance = all_reduce_map_[ar_id].back().distance;
+            CHECK_EQ(prev_distance, pair.distance)
+                << "All ARs with the same AR ID must have the same distance "
+                   "from the corresponding CRSs. Found: "
+                << prev_distance << " and " << pair.distance;
+          }
+          all_reduce_map_[ar_id].push_back(pair);
+          crs_reserved_map_[pair.crs] = ar_id;
         }
       }
     }
   }
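
Stripped of HLO details, the bookkeeping described in the comment above is: per CRS, keep
only the AR with the longest path to it, and remember every discarded AR id so shorter paths
seen later are skipped. A hedged standalone sketch of just that rule (container and type
names are hypothetical, not the pass's actual members):

    #include <map>
    #include <set>

    struct Tracked {
      int ar_id;
      long distance;
    };

    // best: CRS key -> currently tracked AR; discarded: AR ids to skip later.
    void TrackPair(std::map<int, Tracked>& best, std::set<int>& discarded,
                   int crs_key, Tracked candidate) {
      if (discarded.count(candidate.ar_id)) return;
      auto it = best.find(crs_key);
      if (it == best.end()) {
        best[crs_key] = candidate;           // First AR seen for this CRS.
      } else if (it->second.distance < candidate.distance) {
        discarded.insert(it->second.ar_id);  // Drop the shorter path.
        best[crs_key] = candidate;
      } else {
        discarded.insert(candidate.ar_id);   // Keep the previously tracked AR.
      }
    }
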
@@ -253,11 +291,11 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
   for (auto it : all_reduce_map_) {
     auto all_reduce_id = it.first;
-    auto instruction_vec = it.second;
-    CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
-    auto instr_0 = instruction_vec[0];
-    for (int i = 1; i < instruction_vec.size(); ++i) {
-      auto instr_i = instruction_vec[i];
+    auto pairs_vec = it.second;
+    CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
+    auto instr_0 = pairs_vec[0].ar;
+    for (int i = 1; i < pairs_vec.size(); ++i) {
+      auto instr_i = pairs_vec[i].ar;
       auto next_0 = instr_0->users()[0];
       auto next_i = instr_i->users()[0];
       absl::flat_hash_map<int64, int64> visited_pairs;
@@ -281,8 +319,9 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
     return false;
   }
   for (auto it : all_reduce_map_) {
-    auto instruction_vec = it.second;
-    for (auto all_reduce : instruction_vec) {
+    auto pairs_vec = it.second;
+    for (auto pair : pairs_vec) {
+      auto all_reduce = pair.ar;
       auto parent_computation = all_reduce->parent();
       auto all_reduce_id = all_reduce->all_reduce_id();
       auto prev = all_reduce->mutable_operand(0);
@@ -303,16 +342,23 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
                   ? next->operands()[1]
                   : next->operands()[0];
           // To move the AR past the addition/subtraction, we need to divide
-          // other_operand by the number of spatial partitions.
-          auto shape = other_operand->shape();
-          Literal lit(shape);
-          lit.PopulateWithValue<float>(num_spatial_partitions_);
-          auto divisor = parent_computation->AddInstruction(
-              HloInstruction::CreateConstant(lit.Clone()));
-          auto division =
-              parent_computation->AddInstruction(HloInstruction::CreateBinary(
-                  shape, HloOpcode::kDivide, other_operand, divisor));
-          TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
+          // other_operand by the number of spatial partitions, except if
+          // other_operand is a cross-module AR, which can be eliminated.
+          if (other_operand->IsCrossModuleAllReduce() &&
+              other_operand->user_count() == 1) {
+            TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
+                other_operand->mutable_operand(0)));
+          } else {
+            auto shape = other_operand->shape();
+            Literal lit(shape);
+            lit.PopulateWithValue<float>(num_spatial_partitions_);
+            auto divisor = parent_computation->AddInstruction(
+                HloInstruction::CreateConstant(lit.Clone()));
+            auto division = parent_computation->AddInstruction(
+                HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
+                                             other_operand, divisor));
+            TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
+          }
           break;
         }
         default:
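
The division step preserves the final value by a one-line identity: with P partitions,
summing x_i + c/P over all cores gives sum(x_i) + c, exactly what the eliminated
cross-module AR followed by the addition would have produced. A quick numeric check in
plain C++ (values arbitrary; the replica dimension is ignored for brevity):

    #include <cassert>

    int main() {
      const int P = 4;  // Stands in for num_spatial_partitions_.
      const double x[P] = {1, 2, 3, 4};
      const double c = 8.0;  // Stands in for other_operand.
      double before = 0, after = 0;
      for (int i = 0; i < P; ++i) before += x[i];
      before += c;  // Cross-module AR of x, then add c.
      // After the rewrite: each core adds c / P, then an all-core AR sums.
      for (int i = 0; i < P; ++i) after += x[i] + c / P;
      assert(before == after);  // Both are 18.
      return 0;
    }
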
@@ -26,11 +26,47 @@ limitations under the License.
 namespace xla {
 
 // When the HLO graph contains a cross-module AllReduce, followed by some simple
-// linear operations, followed by a cross-replica AllReduce, we can combine the
-// CMAR and the CRAR, to use an efficient AllReduce implementation that fully
-// utilizes the interconnect bandwidth.
+// linear operations, followed by a cross-replica AllReduce (also known as
+// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an
+// efficient AllReduce implementation that fully utilizes the interconnect
+// bandwidth.
 // Such sequences appear in spatially partitioned models.
-// This pass must run right after spatial partitioning.
+// This pass must run right after spatial partitioning, when the code is still
+// in a single HLO module.
+//
+// The steps are:
+// 1) Find CMARs followed by simple ops followed by CRARs.
+// 2) Group CMARs by all_reduce_id. They must all be rewritten.
+// 3) Prove that the CMAR patterns in each core produce the same result.
+// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
+//    other operand by the number of spatial partitions.
+// 5) Turn the CRAR into an all-core AllReduce.
+//
+// The pass also handles the case where multiple CMARs lead to the same CRAR,
+// and eliminates all CMARs. This graph:
+//
+//             Y
+//             |
+//   X   CMAR_2   Z
+//   |      \    /
+// CMAR_1     +
+//      \    /
+//        +
+//        |
+//      CRAR
+//
+// gets rewritten to:
+//
+//           Z   num_partitions
+//            \  /
+//       Y    div
+//        \   /
+//    X     +
+//     \   /
+//       +
+//       |
+//  all-core AR
+//
 class ArCrsCombiner : public HloModulePass {
  public:
   ArCrsCombiner(int num_spatial_partitions)
@@ -43,6 +79,28 @@ class ArCrsCombiner : public HloModulePass {
                                            HloInstruction* i2);
 
  private:
+  // We used this struct because multiple ARs could be paired with the same CRS.
+  // In this case, we want to select the AR that is furthest from the CRS,
+  // because it makes it easier to eliminate all ARs during RewriteGraph.
+  struct ArCrsPair {
+    HloInstruction* ar;
+    HloInstruction* crs;
+    // The length of the path from AR to CRS in the HLO graph.
+    int64 distance;
+
+    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
+              int64 dist)
+        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}
+
+    string ToString() {
+      return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(),
+                          ", distance: ", distance, ")");
+    }
+  };
+
+  absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
+      HloInstruction* instruction);
+
   // If the passed instruction is a while parameter, and the while body is only
   // called by a single while instruction, return the while instruction.
   absl::optional<HloInstruction*> WhileFromBodyParameter(
@@ -80,8 +138,8 @@ class ArCrsCombiner : public HloModulePass {
 
   int num_spatial_partitions_;
 
-  // Map from all-reduce ids to the all reduce instructions.
-  absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
+  // Map from all-reduce ids to the AR/CRS pairs.
+  absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;
 
   // Map from a CRS instruction to the all-reduce ID of the AR paired with the
   // CRS. Sometimes, several ARs in the code could be paired with the same CRS.
@@ -1005,11 +1005,11 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
       op::Tuple(op::AllReduce(op::Add(
                     op::Add(op::Parameter(),
                             op::Divide(op::Constant(), op::Constant())),
-                    op::Divide(op::AllReduce(), op::Constant()))),
+                    op::Parameter())),
                 op::AllReduce(op::Add(
                     op::Add(op::Parameter(),
                             op::Divide(op::Constant(), op::Constant())),
-                    op::Divide(op::AllReduce(), op::Constant())))));
+                    op::Parameter()))));
   auto crs_after =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_after = crs_after->replica_groups();
@@ -1093,15 +1093,17 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   ArCrsCombiner combiner(2);
   auto changed = combiner.Run(module.get()).ValueOrDie();
   EXPECT_TRUE(changed);
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Tuple(op::AllReduce(op::Add(
-                            op::Parameter(),
-                            op::Divide(op::Add(op::AllReduce(), op::Constant()),
-                                       op::Constant()))),
-                        op::AllReduce(op::Add(
-                            op::Parameter(),
-                            op::Divide(op::Add(op::AllReduce(), op::Constant()),
-                                       op::Constant())))));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Tuple(op::AllReduce(op::Add(
+                    op::Parameter(),
+                    op::Add(op::Parameter(),
+                            op::Divide(op::Constant(), op::Constant())))),
+                op::AllReduce(op::Add(
+                    op::Parameter(),
+                    op::Add(op::Parameter(),
+                            op::Divide(op::Constant(), op::Constant()))))));
 
   auto crs_after =
       module->entry_computation()->root_instruction()->operands()[0];
   auto replica_groups_after = crs_after->replica_groups();
@@ -286,7 +286,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto* sort,
       MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
-                  {key, value}, 0, &builder, module.get()));
+                  {key, value}, 0, /*is_stable=*/false, &builder,
+                  module.get()));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
 
@@ -314,7 +315,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto* sort,
       MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}),
-                  {key, value}, 0, &builder, module.get()));
+                  {key, value}, 0, /*is_stable=*/false, &builder,
+                  module.get()));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
@@ -673,9 +673,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   if (embed_ir_in_executable) {
     ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
   }
-  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
 
   XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
+  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
 
   // JIT compile the LLVM IR module to in-memory machine code.
   jit->AddModule(std::move(llvm_module));
@@ -963,8 +963,8 @@ Status EmitBatchDotOperation(
   KernelSupportLibrary ksl(b);
 
   return ksl.ForWithStatus(
-      "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1,
-      [&](llvm::Value* indvar) {
+      llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
+      /*step=*/1, [&](llvm::Value* indvar) {
         DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
         adjusted_dim_numbers.clear_lhs_batch_dimensions();
         adjusted_dim_numbers.clear_rhs_batch_dimensions();
@@ -583,7 +583,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
       b_.getVoidTy(),
       {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
        b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
-       b_.getInt32Ty()->getPointerTo(), b_.getInt8PtrTy(),
+       b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(),
        b_.getInt64Ty()->getPointerTo(), less_than_function->getType()},
       /*isVarArg=*/false);
   auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
@@ -616,8 +616,8 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
       {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
        b_.getInt64(lower_dimensions), values,
        b_.getInt32(sort->operand_count()), sizes,
-       GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
-       less_than_function});
+       b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(),
+       GetProfileCountersArgument(), less_than_function});
 
   if (sort->values_count() > 0) {
     llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
@@ -32,8 +32,8 @@ using tensorflow::int64;
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
     int64 a, int64 b, int64 c, char** values, int32 values_count,
-    int32* values_primitive_type_size_in_bytes, char* run_options,
-    int64* prof_counters,
+    int32* values_primitive_type_size_in_bytes, bool is_stable,
+    char* run_options, int64* prof_counters,
     void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) {
   // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT
   // code, so msan can't tell they are initialized.
@@ -69,22 +69,27 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
     int64 base_offset =
         index % sort_dimension_offset +
         (index - index % sort_dimension_offset) * sort_dimension_elements;
-    std::stable_sort(
-        indices.get(), indices.get() + sort_dimension_elements,
-        [&](int64 a, int64 b) -> bool {
-          int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
-                                   values_primitive_type_size_in_bytes[0];
-          int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
-                                   values_primitive_type_size_in_bytes[0];
-          for (int32 i = 0; i < values_count; ++i) {
-            comparison_values[i * 2] = values[i] + memory_index_lhs;
-            comparison_values[i * 2 + 1] = values[i] + memory_index_rhs;
-          }
-          char result = 0;  // Overwritten by less_than.
-          less_than(&result, run_options, comparison_values.get(), nullptr,
-                    prof_counters);
-          return result != 0u;
-        });
+    auto compare_function = [&](int64 a, int64 b) -> bool {
+      int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
+                               values_primitive_type_size_in_bytes[0];
+      int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
+                               values_primitive_type_size_in_bytes[0];
+      for (int32 i = 0; i < values_count; ++i) {
+        comparison_values[i * 2] = values[i] + memory_index_lhs;
+        comparison_values[i * 2 + 1] = values[i] + memory_index_rhs;
+      }
+      char result = 0;  // Overwritten by less_than.
+      less_than(&result, run_options, comparison_values.get(), nullptr,
+                prof_counters);
+      return result != 0u;
+    };
+    if (is_stable) {
+      std::stable_sort(indices.get(), indices.get() + sort_dimension_elements,
+                       compare_function);
+    } else {
+      std::sort(indices.get(), indices.get() + sort_dimension_elements,
+                compare_function);
+    }
 
     // Reorder the values according to the order defined by 'indices'.
     for (int32 idx = 0; idx < values_count; ++idx) {
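
Why the new is_stable branch matters: when sorting parallel arrays through an index
permutation, only a stable sort preserves the relative order of equal keys. A
self-contained demonstration (values arbitrary):

    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int> keys = {2, 1, 2, 1};
      std::vector<char> vals = {'a', 'b', 'c', 'd'};
      std::vector<int> idx(keys.size());
      std::iota(idx.begin(), idx.end(), 0);
      auto cmp = [&](int a, int b) { return keys[a] < keys[b]; };
      // stable_sort guarantees 'b' stays before 'd' and 'a' before 'c';
      // plain std::sort may reorder ties, hence the runtime branch above.
      std::stable_sort(idx.begin(), idx.end(), cmp);
      for (int i : idx) std::printf("%c", vals[i]);  // Prints "bdac".
      std::printf("\n");
      return 0;
    }
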
@@ -22,15 +22,14 @@ limitations under the License.
 extern "C" {
 
 // Each entry in 'values' represents a 3-dimensional shape with dimensions
-// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending
-// order according to the results of comparisons using the provided 'less_than'
+// [a, b, c]. The 'b' dimension of each shape is sorted into ascending order
+// according to the results of comparisons using the provided 'less_than'
 // function. 'values_count' must be > 0 and specifies the number of entries in
 // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive
 // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]'
-// bytes. The elements in each 'values' shape are reordered in the same way
-// according to the comparisons using the first shape. 'run_options' and
-// 'prof_counters' are passed through to the less-than function, which expects
-// the following arguments:
+// bytes. 'is_stable' specifies whether the sorting should be stable.
+// 'run_options' and 'prof_counters' are passed through to the less-than
+// function, which expects the following arguments:
 // - pointer to the return value buffer (char*)
 // - xla::ExecutableRunOptions = 'run_options' (char*)
 // - pointers to the parameter buffers (char**)
@@ -39,8 +38,8 @@ extern "C" {
 extern void __xla_cpu_runtime_KeyValueSort(
     tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
     char** values, tensorflow::int32 values_count,
-    tensorflow::int32* values_primitive_type_size_in_bytes, char* run_options,
-    tensorflow::int64* prof_counters,
+    tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable,
+    char* run_options, tensorflow::int64* prof_counters,
     void (*less_than)(char*, char*, char**, char**, tensorflow::int64*));
 }
 
@@ -938,6 +938,53 @@ void TiledSmallGemmEmitter::EmitTiledGemm(
   });
 }
 
+llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
+  llvm::Type* type =
+      llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
+  while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
+    type = array_type->getElementType();
+  }
+
+  return type->getPointerTo();
+}
+
+struct GemvBuffersWithCanonicalType {
+  llvm::Value* lhs_canonicalized;
+  llvm::Value* rhs_canonicalized;
+  llvm::Value* addend_canonicalized;
+  llvm::Value* result_canonicalized;
+};
+
+GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
+    llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
+    llvm::Value* result, llvm::IRBuilder<>* b) {
+  // We characterize a GEMV operation via M and K, since N is implicitly 1.
+  // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
+  // by the same GEMV that multiplies [5,6] with [1,6]. However, the
+  // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
+  // sense -- the in memory representations are the same) since they're computed
+  // from the `xla::Shape`s. Since we want to be able to call the same
+  // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV
+  // inputs here into the same type.
+  GemvBuffersWithCanonicalType buffers_with_canonical_type;
+  llvm::Type* lhs_type = lhs->getType();
+  llvm::Type* rhs_type = rhs->getType();
+  llvm::Type* addend_type = addend ? addend->getType() : nullptr;
+  llvm::Type* result_type = result->getType();
+
+  buffers_with_canonical_type.lhs_canonicalized =
+      b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
+  buffers_with_canonical_type.rhs_canonicalized =
+      b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
+  buffers_with_canonical_type.addend_canonicalized =
+      addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
+             : nullptr;
+  buffers_with_canonical_type.result_canonicalized =
+      b->CreateBitCast(result, GetPointerToElementType(result_type));
+
+  return buffers_with_canonical_type;
+}
+
 }  // namespace
 
 void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
@@ -950,12 +997,18 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
 
+  GemvBuffersWithCanonicalType canonical_inputs =
+      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
+
   KernelSupportLibrary::EmitAndCallOutlinedKernel(
       /*enable_fast_math=*/enable_fast_math,
-      /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs,
-      rhs, addend, result,
-      [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
-          llvm::Value* result) {
+      /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
+      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
+      canonical_inputs.addend_canonicalized,
+      canonical_inputs.result_canonicalized,
+      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
+                                      llvm::Value* addend,
+                                      llvm::Value* result) {
         RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                    result, b);
         emitter.Emit();
@@ -972,12 +1025,18 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
       /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
       /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
 
+  GemvBuffersWithCanonicalType canonical_inputs =
+      GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
+
   KernelSupportLibrary::EmitAndCallOutlinedKernel(
      /*enable_fast_math=*/enable_fast_math,
-      /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs,
-      rhs, addend, result,
-      [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
-          llvm::Value* result) {
+      /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
+      canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
+      canonical_inputs.addend_canonicalized,
+      canonical_inputs.result_canonicalized,
+      [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
+                                      llvm::Value* addend,
+                                      llvm::Value* result) {
         ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
                                                       result, b);
         emitter.Emit();
@@ -1367,26 +1367,69 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
           llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_);
   llvm::Type* raw_value_ty = raw_value->getType();
 
-  // Convert raw integer to float in range [0, 1) if the element is a float.
+  // If we're generating a floating-point value, convert the raw integer R (i.e.
+  // `raw_value`) to a float in the range [0, 1).
+  //
+  // The basic approach is to choose a significand and exponent such that the
+  // significand is uniformly distributed and the exponent is distributed, well,
+  // exponentially (it's more likely to be close to 0 than far from 0).
+  //
+  // An easy way to do this is to say that the significand is the first S bits
+  // of R, and the exponent is determined by the number of trailing zeroes in R,
+  // exp = 2^-(cttz(R) + 1).  (+1 because the largest exponent should be -1;
+  // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.)
+  //
+  // This results in a small bias.  Namely, if R has enough trailing zeroes, the
+  // significand and exponent will "overlap".  As a concrete example, consider
+  //
+  //          20 X's                 12 zeroes
+  //   R = 0bXXXXXXXXXXXXXXXXXXXX000000000000
+  //
+  // Here the exponent is 2^-13 because R has 12 trailing zeroes.  The
+  // significand is made up of the first 23 most-significant bits of R, which we
+  // observe contain 3 zeroes.  This is biased because any random value with
+  // exponent 2^-12 will have a significand which ends in `000`.
+  //
+  // For f32s, this problem occurs only when there are more than 32-23 = 9
+  // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the
+  // probability of a large bias (i.e. many trailing 0s in the significand) is
+  // exponentially low.  So we deem this acceptable.
   llvm::Value* elem_value = raw_value;
   if (elem_ir_ty->isFloatingPointTy()) {
-    unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits();
-    CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64);
-    // Perform the division using the float type with the same number of bits
-    // as the raw value to avoid overflow.
-    if (raw_value_size_in_bits == 32) {
-      elem_value = UIToFP(elem_value, b_->getFloatTy());
-      elem_value = FDiv(elem_value,
-                        llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
-    } else {
-      elem_value = UIToFP(elem_value, b_->getDoubleTy());
-      elem_value = FDiv(
-          elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
-    }
-
-    if (elem_ir_ty != elem_value->getType()) {
-      elem_value = FPTrunc(elem_value, elem_ir_ty);
-    }
+    const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics();
+    const int bits = raw_value_ty->getPrimitiveSizeInBits();
+    CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics));
+
+    // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the
+    // implicit "1." at the beginning of the significand.
+    const int significand_bits =
+        llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1;
+
+    llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic(
+        llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()},
+        {raw_value->getType()}, b_);
+    llvm::Value* significand = LShr(raw_value, bits - significand_bits);
+
+    // Exponent bias is -127 for f32, meaning that if the exponent is E and the
+    // significand is S, then the value of the number is 2^(E - 127) * (1.S).
+    //
+    // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as
+    // E = 126 - cttz.
+    //
+    // For f64, this is all the same, except the bias is -1023.
+    //
+    // In IEEE floating point, the absolute value of the exponent bias equals
+    // the value of the largest possible exponent.
+    const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics);
+    llvm::Value* exponent =
+        Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz);
+
+    // Now just slot everything into place!  The `Trunc` is here because
+    // raw_value may be larger than our float destination.
+    elem_value =
+        BitCast(Trunc(Or(Shl(exponent, significand_bits), significand),
+                      b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())),
+                elem_ir_ty);
   }
 
   // Convert the value for the requested distribution.
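
The constants in the new comment can be checked with a standalone scalar version of the
same bit trick for f32. This is a sketch for verifying the math, not the emitter code;
__builtin_ctz is a GCC/Clang builtin:

    #include <cstdint>
    #include <cstring>

    float RawBitsToUnitFloat(uint32_t r) {
      // cttz == 0 corresponds to exponent 2^-1; cttz(0) is defined as 32,
      // matching is_zero_undef=false in the emitted intrinsic.
      const int cttz = (r == 0) ? 32 : __builtin_ctz(r);
      // f32 has 23 explicit significand bits; take the 23 most-significant
      // bits of r, mirroring LShr(raw_value, bits - significand_bits).
      const uint32_t significand = r >> (32 - 23);
      // Biased exponent field E = 126 - cttz encodes the value 2^(-1 - cttz).
      const uint32_t exponent = static_cast<uint32_t>(126 - cttz);
      const uint32_t bits = (exponent << 23) | significand;
      float out;
      std::memcpy(&out, &bits, sizeof(out));
      return out;  // Always in [0, 1); at most (2 - 2^-23) * 2^-1.
    }
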
@@ -216,8 +216,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
   llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
       const HloInstruction* hlo,
       const HloToElementGeneratorMap& operand_to_generator);
+
   // Converts the raw value generated by a random number generation algorithm
   // to the distribution requested by the RNG HloInstruction.
+  //
+  // Precondition: raw_value has at least as many bits as hlo's element type.
   StatusOr<llvm::Value*> ConvertValueForDistribution(
       const HloInstruction* hlo,
       const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
@@ -765,6 +765,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:reduce_precision_insertion",
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:sort_simplifier",
+        "//tensorflow/compiler/xla/service:stable_sort_expander",
        "//tensorflow/compiler/xla/service:transpose_folding",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
        "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
@@ -82,6 +82,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
+#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -195,6 +196,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     pipeline.AddPass<ConvolutionGroupConverter>(
         cost_model,
         /*convert_batch_groups_only=*/true);
+    // Expand the sort op to support stable sorting if required.
+    pipeline.AddPass<StableSortExpander>();
     // Convert BF16 operations to F32 operations so that the GPU backend can
     // support BF16 operations without directly implementing a BF16 lowering for
     // most ops.
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 60
+// Next ID: 61
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -175,6 +175,9 @@ message HloInstructionProto {
   // partners.
   bool is_host_transfer = 47;
 
+  // Whether this Sort instruction should be stable.
+  bool is_stable = 60;
+
   xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
 
   // Precision configuration for the instruction. Has backend-specific meaning.
@ -275,7 +275,7 @@ StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
|
|||||||
|
|
||||||
StatusOr<HloInstruction*> MakeSortHlo(
|
StatusOr<HloInstruction*> MakeSortHlo(
|
||||||
const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
|
const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
|
||||||
int64 dimension_to_sort, HloComputation::Builder* builder,
|
int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
|
||||||
HloModule* module) {
|
HloModule* module) {
|
||||||
CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
|
CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
|
||||||
HloComputation* compare_computation;
|
HloComputation* compare_computation;
|
||||||
@ -293,7 +293,7 @@ StatusOr<HloInstruction*> MakeSortHlo(
|
|||||||
compare_computation =
|
compare_computation =
|
||||||
module->DeepCloneComputation(new_module->entry_computation(), &context);
|
module->DeepCloneComputation(new_module->entry_computation(), &context);
|
||||||
return builder->AddInstruction(HloInstruction::CreateSort(
|
return builder->AddInstruction(HloInstruction::CreateSort(
|
||||||
sort_shape, dimension_to_sort, operands, compare_computation));
|
sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
|
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
|
||||||
@@ -126,10 +126,10 @@ StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
 // Creates a Sort HLO instruction and adds it to the computation containing the
 // operands. All operands must be in the same computation. Also creates a
 // default compare sub-computation which sorts the first operand into ascending
-// order.
+// order. 'is_stable' specifies whether the sorting should be stable.
 StatusOr<HloInstruction*> MakeSortHlo(
     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
-    int64 dimension_to_sort, HloComputation::Builder* builder,
+    int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
     HloModule* module);
 
 // Creates an R1 Constant HLO instruction of the given PrimitiveType with the
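For orientation, a minimal sketch of how a caller passes the new flag; it mirrors the updated test call sites below, and the builder/module setup here is assumed rather than part of the change:

  // Sketch only: request a stable sort through the new 'is_stable' parameter.
  HloComputation::Builder builder("stable_sort_example");
  Shape keys_shape = ShapeUtil::MakeShape(F32, {1024});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSIGN_OR_RETURN(
      HloInstruction * sort,
      MakeSortHlo(keys_shape, {keys}, /*dimension_to_sort=*/0,
                  /*is_stable=*/true, &builder, module));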
@@ -2363,7 +2363,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
   auto keys = builder.AddInstruction(
       HloInstruction::CreateParameter(0, keys_shape, "keys"));
   TF_ASSERT_OK_AND_ASSIGN(
-      auto* sort, MakeSortHlo(keys_shape, {keys}, -1, &builder, module_.get()));
+      auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false,
+                              &builder, module_.get()));
 
   computation_ = module_->AddEntryComputation(builder.Build());
   RunAnalysis();
@@ -2385,7 +2386,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto* sort,
       MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
-                  {keys, values}, 0, &builder, module_.get()));
+                  {keys, values}, 0, /*is_stable=*/false, &builder,
+                  module_.get()));
 
   computation_ = module_->AddEntryComputation(builder.Build());
   RunAnalysis();
@@ -2670,12 +2670,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
         const Literal& high =
             parent_->GetEvaluatedLiteralFor(random->operand(1));
 
-        std::uniform_real_distribution<NativeT> generator(
-            low.Get<NativeT>({}), high.Get<NativeT>({}));
+        // std::uniform_real_distribution(a, b) can sometimes return a value
+        // equal to b. Unclear if this is a spec bug or an implementation bug
+        // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open
+        // interval, so we have to re-sample if we get `b` out.
+        //
+        // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
+        // [1] https://bugs.llvm.org/show_bug.cgi?id=18767
+        // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524
+        auto low_val = low.Get<NativeT>({});
+        auto high_val = high.Get<NativeT>({});
+        std::uniform_real_distribution<NativeT> generator(low_val, high_val);
         TF_RETURN_IF_ERROR(
             result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
-              return generator(parent_->engine_);
+              while (true) {
+                NativeT v = generator(parent_->engine_);
+                if (v != high_val) {
+                  return v;
+                }
+              }
             }));
         break;
       }
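The same workaround can be seen in isolation. The following is a self-contained sketch (names are ours, not part of the change) showing why rejecting the upper bound yields a half-open interval:

  #include <iostream>
  #include <random>

  // Some standard library implementations of uniform_real_distribution(a, b)
  // can return exactly b, so samples equal to the upper bound are rejected
  // to guarantee a half-open [a, b).
  double SampleHalfOpen(std::mt19937& engine, double low, double high) {
    std::uniform_real_distribution<double> dist(low, high);
    while (true) {
      double v = dist(engine);
      if (v != high) {
        return v;  // v is guaranteed to lie in [low, high).
      }
    }
  }

  int main() {
    std::mt19937 engine(42);
    std::cout << SampleHalfOpen(engine, 0.0, 1.0) << "\n";
    return 0;
  }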
@@ -214,7 +214,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           << proto.called_computation_ids_size();
       auto sort_operands = all_operands();
       instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
-                               computations(0));
+                               computations(0), proto.is_stable());
       break;
     }
     case HloOpcode::kTranspose:
@@ -1170,9 +1170,10 @@ HloInstruction::CreateBroadcastSequence(
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
     const Shape& shape, int64 dimension,
-    absl::Span<HloInstruction* const> operands, HloComputation* compare) {
+    absl::Span<HloInstruction* const> operands, HloComputation* compare,
+    bool is_stable) {
   return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
-                                               compare);
+                                               compare, is_stable);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
@@ -384,6 +384,14 @@ class HloInstruction {
 
   // Creates a random number generation instruction that fills a shape with
   // random numbers from a given distribution.
+  //
+  // The parameters to the instruction are interpreted as follows:
+  //
+  //  - If `distribution` is RNG_UNIFORM, generates a number in range
+  //    [param0, param1).
+  //
+  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
+  //    with mean `param0` and standard deviation `param1`.
   static std::unique_ptr<HloInstruction> CreateRng(
       const Shape& shape, RandomDistribution distribution,
       absl::Span<HloInstruction* const> parameters);
@@ -678,10 +686,11 @@ class HloInstruction {
   // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
   // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
   // specific index positions which should be compared, and should return a
-  // PRED.
+  // PRED. 'is_stable' specifies whether stable sorting is required.
   static std::unique_ptr<HloInstruction> CreateSort(
       const Shape& shape, int64 dimension,
-      absl::Span<HloInstruction* const> operands, HloComputation* compare);
+      absl::Span<HloInstruction* const> operands, HloComputation* compare,
+      bool is_stable);
 
   // Creates a while instruction, given a condition computation, a body
   // computation, and the initial value for the input of the computations. For
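As a hedged illustration of the documented parameter convention (CreateConstant and LiteralUtil::CreateR0 are standard factories, but this exact snippet is ours, not part of the change):

  // Sketch: an RNG_UNIFORM instruction drawing f32 values from [0, 1).
  HloInstruction* low = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* high = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  HloInstruction* rng = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {16}), RNG_UNIFORM, {low, high}));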
@ -1286,6 +1295,9 @@ class HloInstruction {
|
|||||||
backend_config_ = std::move(config_str);
|
backend_config_ = std::move(config_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_default_config() const { return is_default_config_; }
|
||||||
|
void set_default_config() { is_default_config_ = true; }
|
||||||
|
|
||||||
// Returns a string representation of a proto in the format used by
|
// Returns a string representation of a proto in the format used by
|
||||||
// raw_backend_config_string.
|
// raw_backend_config_string.
|
||||||
//
|
//
|
||||||
@ -1734,6 +1746,10 @@ class HloInstruction {
|
|||||||
// HLO. See the documentation on backend_config().
|
// HLO. See the documentation on backend_config().
|
||||||
string backend_config_;
|
string backend_config_;
|
||||||
|
|
||||||
|
// This field is assigned to true when backend_config_ is assigned to
|
||||||
|
// a default configuration.
|
||||||
|
bool is_default_config_ = false;
|
||||||
|
|
||||||
// String identifier for instruction.
|
// String identifier for instruction.
|
||||||
string name_;
|
string name_;
|
||||||
|
|
||||||
@@ -659,8 +659,11 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
 
 HloSortInstruction::HloSortInstruction(
     const Shape& shape, int64 dimension,
-    absl::Span<HloInstruction* const> operands, HloComputation* compare)
-    : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+    absl::Span<HloInstruction* const> operands, HloComputation* compare,
+    bool is_stable)
+    : HloInstruction(HloOpcode::kSort, shape),
+      dimensions_({dimension}),
+      is_stable_(is_stable) {
   for (auto* value : operands) {
     AppendOperand(value);
   }
@@ -672,12 +675,18 @@ HloInstructionProto HloSortInstruction::ToProto() const {
   for (int64 dimension : dimensions_) {
     proto.add_dimensions(dimension);
   }
+  proto.set_is_stable(is_stable());
   return proto;
 }
 
 std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
-  return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
+  std::vector<string> attrs;
+  attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
+  if (is_stable()) {
+    attrs.push_back("is_stable=true");
+  }
+  return attrs;
 }
 
 bool HloSortInstruction::IdenticalSlowPath(
@@ -688,14 +697,17 @@ bool HloSortInstruction::IdenticalSlowPath(
   if (dimensions() != casted_other.dimensions()) {
     return false;
   }
+  if (is_stable() != casted_other.is_stable()) {
+    return false;
+  }
   return eq_computations(to_apply(), other.to_apply());
 }
 
 std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
-  return absl::make_unique<HloSortInstruction>(shape, dimensions(0),
-                                               new_operands, to_apply());
+  return absl::make_unique<HloSortInstruction>(
+      shape, dimensions(0), new_operands, to_apply(), is_stable());
 }
 
 HloTransposeInstruction::HloTransposeInstruction(
@@ -447,7 +447,7 @@ class HloSortInstruction : public HloInstruction {
  public:
   explicit HloSortInstruction(const Shape& shape, int64 dimension,
                               absl::Span<HloInstruction* const> operands,
-                              HloComputation* compare);
+                              HloComputation* compare, bool is_stable);
   // Returns the dimension sizes or numbers associated with this instruction.
   const std::vector<int64>& dimensions() const override { return dimensions_; }
   int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -460,6 +460,7 @@ class HloSortInstruction : public HloInstruction {
   HloInstruction* mutable_keys() { return mutable_operand(0); }
   // Returns the number of value operands.
   int64 values_count() const { return operand_count() - 1; }
+  bool is_stable() const { return is_stable_; }
 
  private:
   std::vector<string> ExtraAttributesToStringImpl(
@@ -474,6 +475,7 @@ class HloSortInstruction : public HloInstruction {
       HloCloneContext* context) const override;
 
   std::vector<int64> dimensions_;
+  bool is_stable_;
 };
 
 class HloTransposeInstruction : public HloInstruction {
@@ -895,6 +895,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<std::vector<int64>> dimensions;
       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &dimensions};
+      optional<bool> is_stable = false;
+      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
       optional<HloComputation*> to_apply;
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &to_apply};
@@ -902,8 +904,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
           dimensions->size() != 1) {
         return false;
       }
-      instruction = builder->AddInstruction(HloInstruction::CreateSort(
-          shape, dimensions->at(0), operands, to_apply.value()));
+      instruction = builder->AddInstruction(
+          HloInstruction::CreateSort(shape, dimensions->at(0), operands,
+                                     to_apply.value(), is_stable.value()));
       break;
     }
     case HloOpcode::kTuple: {
@@ -1145,6 +1145,24 @@ ENTRY Sort {
 ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
 }
 
+)"
+},
+// Sort (Key) is_stable=true
+{
+"SortKeyStable",
+R"(HloModule sort
+
+compare {
+  p.0.lhs = f32[] parameter(0)
+  p.0.rhs = f32[] parameter(1)
+  ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
+ENTRY Sort {
+  x = f32[1024]{0} parameter(0)
+  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
+}
+
 )"
 },
 // Conditional
@@ -176,7 +176,7 @@ StatusOr<Literal> HloRunner::Execute(
                       TransferLiteralsToDevice(arguments));
   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                       ExecuteWithDeviceBuffers(
-                          /*module=*/std::move(executable),
+                          /*executable=*/executable.get(),
                           /*arguments=*/argument_buffers,
                           /*profile=*/profile));
   return TransferLiteralFromDevice(result);
@@ -235,7 +235,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
 }
 
 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    std::unique_ptr<Executable> executable,
+    Executable* executable,
     const absl::Span<const ShapedBuffer* const> arguments,
     ExecutionProfile* profile) {
   // Get service run options.
@@ -254,7 +254,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
 }
 
 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    std::unique_ptr<Executable> executable,
+    Executable* executable,
    const absl::Span<const ScopedShapedBuffer> arguments,
     ExecutionProfile* profile) {
   std::vector<const ShapedBuffer*> argument_pointers;
@@ -144,13 +144,16 @@ class HloRunner {
       const absl::Span<const ScopedShapedBuffer> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 
+  // In the following two calls, "executable" is not a unique_ptr to allow
+  // reuse of the Executable. This call may update the profile information in
+  // *executable.
   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      std::unique_ptr<Executable> executable,
+      Executable* executable,
       const absl::Span<const ShapedBuffer* const> arguments,
       ExecutionProfile* profile = nullptr);
 
   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      std::unique_ptr<Executable> executable,
+      Executable* executable,
       const absl::Span<const ScopedShapedBuffer> arguments,
       ExecutionProfile* profile = nullptr);
 
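A plausible usage sketch of the new raw-pointer signature (the runner and argument names are illustrative): the caller keeps ownership and can run the same compiled Executable repeatedly:

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      runner.CreateExecutable(std::move(module),
                                              /*run_hlo_passes=*/true));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_a,
                      runner.ExecuteWithDeviceBuffers(executable.get(), args));
  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_b,
                      runner.ExecuteWithDeviceBuffers(executable.get(), args));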
@@ -36,6 +36,9 @@ StatusOr<bool> OpExpanderPass::Run(HloModule* module) {
   for (HloInstruction* inst : matching_instructions) {
     TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root,
                         ExpandInstruction(inst));
+    if (expanded_root == nullptr) {
+      continue;
+    }
     TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root));
   }
 
@@ -33,7 +33,9 @@ class OpExpanderPass : public HloModulePass {
   // Returns `true` if `instruction` should be expanded by this pass.
   virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0;
 
-  // Returns a replacement for `instruction`.
+  // Returns a replacement for `instruction`, or nullptr if no replacement is
+  // needed (e.g. only the to_apply subcomputation of the instruction was
+  // modified).
   virtual StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* instruction) = 0;
 };
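A hypothetical subclass sketch of the relaxed contract (the class name and matched opcode are ours): returning nullptr tells Run() the instruction was handled in place and needs no replacement:

  class InPlaceExpanderExample : public OpExpanderPass {
   public:
    absl::string_view name() const override { return "in-place-expander"; }

   private:
    bool InstructionMatchesPattern(HloInstruction* instruction) override {
      return instruction->opcode() == HloOpcode::kSort;
    }
    StatusOr<HloInstruction*> ExpandInstruction(
        HloInstruction* instruction) override {
      // Mutate only the matched instruction's subcomputation here, then
      // return nullptr to signal "no replacement needed" to the driver loop.
      return nullptr;
    }
  };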
tensorflow/compiler/xla/service/stable_sort_expander.cc (new file, 204 lines)
@@ -0,0 +1,204 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Looks for an iota operand that can be used as tie breaker in the
+// computation. If no matching iota operand is found, an iota operand is added
+// to Sort. The comparison computation is adjusted to break ties using the
+// values from the iota operand.
+StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
+    HloInstruction* instruction) {
+  auto* sort = Cast<HloSortInstruction>(instruction);
+  HloComputation* computation = sort->parent();
+
+  HloInstruction* expanded_sort = nullptr;
+  absl::flat_hash_set<int64> used_indices;
+  int64 iota_index = -1;
+  for (const HloInstruction* operand : sort->operands()) {
+    // We can only use the iota operand if it has an iota dimension which is
+    // the same as the dimension to sort. Also it should have an integral type
+    // that is large enough for the number of elements in the sort dimension.
+    // For now, we only allow S32, because we expect to find a S32 iota operand
+    // for all Sort ops which are created by TopK.
+    // TODO(b/122298745): Also support other types.
+    if (operand->opcode() == HloOpcode::kIota &&
+        Cast<HloIotaInstruction>(operand)->iota_dimension() ==
+            sort->sort_dimension() &&
+        operand->shape().element_type() == S32) {
+      iota_index = sort->operand_index(operand);
+      break;
+    }
+  }
+
+  // If there is currently no iota operand which we could use for making the
+  // sort stable, we will have to add a new such operand.
+  if (iota_index == -1) {
+    Shape iota_shape = sort->operand(0)->shape();
+    // We might need to use S64 if the number of elements in the sort dimension
+    // is bigger than 2^31 - 1.
+    // TODO(b/122298745): Handle Sort ops where S32 is too small for the number
+    // of elements in the sort dimension.
+    if (iota_shape.dimensions(sort->sort_dimension()) >
+        std::numeric_limits<int32>::max()) {
+      return Unimplemented(
+          "Stable sorting of more than 2^31-1 elements is not implemented");
+    }
+    iota_shape.set_element_type(S32);
+    auto iota = computation->AddInstruction(
+        HloInstruction::CreateIota(iota_shape, sort->sort_dimension()));
+
+    // Create a new comparator.
+    auto comparator = sort->to_apply();
+    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
+        replacements;
+    std::vector<std::unique_ptr<HloInstruction>> extra_parameters;
+    std::vector<HloInstruction*> extra_parameter_ptrs;
+    Shape scalar_shape = ShapeUtil::MakeShape(S32, {});
+    extra_parameters.push_back(HloInstruction::CreateParameter(
+        sort->operand_count() * 2, scalar_shape,
+        absl::StrCat("p.", sort->operand_count(), ".lhs")));
+    extra_parameter_ptrs.push_back(extra_parameters.back().get());
+    extra_parameters.push_back(HloInstruction::CreateParameter(
+        sort->operand_count() * 2 + 1, scalar_shape,
+        absl::StrCat("p.", sort->operand_count(), ".rhs")));
+    extra_parameter_ptrs.push_back(extra_parameters.back().get());
+    sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation(
+        comparator->CloneWithReplacements(std::move(replacements),
+                                          extra_parameter_ptrs)));
+
+    // Replace the original sort op.
+    std::vector<HloInstruction*> new_operands(sort->operands().begin(),
+                                              sort->operands().end());
+    new_operands.push_back(iota);
+    std::vector<Shape> new_shapes = sort->operand_count() == 1
+                                        ? std::vector<Shape>{sort->shape()}
+                                        : sort->shape().tuple_shapes();
+    new_shapes.push_back(iota_shape);
+    Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes);
+    HloInstruction* new_sort = computation->AddInstruction(
+        sort->CloneWithNewOperands(new_sort_shape, new_operands));
+
+    // Add a "wrapper" around the new sort op to make sure we have the same
+    // shape as before. For the rank 1 case, we only need a GetTupleElement,
+    // otherwise we create a Tuple consisting of GetTupleElements of the new
+    // sort.
+    std::vector<HloInstruction*> tuple_elements;
+    tuple_elements.reserve(sort->operand_count());
+    for (int64 i = 0; i < sort->operand_count(); ++i) {
+      tuple_elements.push_back(
+          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+              sort->operand(i)->shape(), new_sort, i)));
+    }
+    expanded_sort = tuple_elements[0];
+    if (tuple_elements.size() > 1) {
+      expanded_sort = computation->AddInstruction(
+          HloInstruction::CreateTuple(tuple_elements));
+    }
+    sort = Cast<HloSortInstruction>(new_sort);
+    iota_index = sort->operand_count() - 1;
+  }
+
+  // Modify the computation to break ties using the iota operand.
+  auto comparator = sort->to_apply();
+  std::vector<HloInstruction*> instructions_postorder =
+      comparator->MakeInstructionPostOrder();
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements;
+  // Look up instr in the replacements map, and return either the replacement,
+  // or instr, if the replacement isn't present.
+  auto replace = [&](HloInstruction* instr) {
+    auto it = replacements.find(instr);
+    if (it == replacements.end()) {
+      return instr;
+    }
+    return it->second;
+  };
+  HloInstruction* old_root = comparator->root_instruction();
+  // The comparison computation gets 2 * n parameters (n being the number of
+  // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two
+  // different scalars of operand i of Sort which are to be compared. The
+  // comparison computation should induce a strict weak order, so if
+  // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to
+  // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the
+  // values to be compared are equivalent, and perform a tie-breaker comparison.
+  //
+  // We clone each instruction with at least one operand, but use as new
+  // operands of the instruction the replacements of the original operands.
+  // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This
+  // should make sure that the cloned root instruction gives the result of the
+  // comparison computation when being called with each scalar pair reversed.
+  // parameters corresponding to the iota operand.
+  for (int64 i = 0; i < comparator->num_parameters(); ++i) {
+    replacements[comparator->parameter_instruction(i)] =
+        comparator->parameter_instruction(i ^ 1);
+  }
+  HloInstruction* cloned_root = nullptr;
+  for (HloInstruction* inst : instructions_postorder) {
+    if (inst->operand_count() == 0) {
+      continue;
+    }
+    std::vector<HloInstruction*> new_operands;
+    new_operands.reserve(inst->operand_count());
+    for (HloInstruction* operand : inst->operands()) {
+      new_operands.push_back(replace(operand));
+    }
+    auto new_instruction =
+        inst->CloneWithNewOperands(inst->shape(), new_operands);
+    replacements[inst] = new_instruction.get();
+    if (inst == old_root) {
+      cloned_root = new_instruction.get();
+    }
+    comparator->AddInstruction(std::move(new_instruction));
+  }
+  CHECK_NE(cloned_root, nullptr);
+  Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
+  HloInstruction* same =
+      comparator->AddInstruction(HloInstruction::CreateBinary(
+          scalar_pred, HloOpcode::kEq, old_root, cloned_root));
+  HloInstruction* tie_breaker =
+      comparator->AddInstruction(HloInstruction::CreateBinary(
+          scalar_pred, HloOpcode::kLt,
+          comparator->parameter_instruction(2 * iota_index),
+          comparator->parameter_instruction(2 * iota_index + 1)));
+  HloInstruction* new_root =
+      comparator->AddInstruction(HloInstruction::CreateTernary(
+          ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
+          old_root));
+  comparator->set_root_instruction(new_root);
+
+  return expanded_sort;
+}
+
+bool StableSortExpander::InstructionMatchesPattern(
+    HloInstruction* instruction) {
+  return instruction->opcode() == HloOpcode::kSort &&
+         Cast<HloSortInstruction>(instruction)->is_stable();
+}
+}  // namespace xla
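Schematically, for a single-operand stable sort the expanded comparator computes the following (instruction names here are ours, shown only for illustration; the trailing s32 parameter pair is the one coming from the iota operand):

  compare.expanded {
    p.0.lhs = f32[] parameter(0)
    p.0.rhs = f32[] parameter(1)
    p.1.lhs = s32[] parameter(2)
    p.1.rhs = s32[] parameter(3)
    fwd = pred[] less-than(p.0.lhs, p.0.rhs)
    rev = pred[] less-than(p.0.rhs, p.0.lhs)
    same = pred[] equal-to(fwd, rev)
    tie = pred[] less-than(p.1.lhs, p.1.rhs)
    ROOT root = pred[] select(same, tie, fwd)
  }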
tensorflow/compiler/xla/service/stable_sort_expander.h (new file, 42 lines)
@@ -0,0 +1,42 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass which expands Sort ops that have the is_stable field set to true
+// into equivalent Sort ops which guarantee stable sorting without relying on
+// the is_stable field.
+class StableSortExpander : public OpExpanderPass {
+ public:
+  absl::string_view name() const override { return "stable-sort-expander"; }
+
+ private:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+  StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
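A short usage sketch of the pass interface (the pipeline name and surrounding setup are illustrative, not part of this change):

  HloPassPipeline pipeline("expand-stable-sorts");
  pipeline.AddPass<StableSortExpander>();
  TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module.get()));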
tensorflow/compiler/xla/service/stable_sort_expander_test.cc (new file, 358 lines)
@@ -0,0 +1,358 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+namespace m = match;
+
+using StableSortExpanderTest = HloTestBase;
+
+// Checks whether 'a' and 'b' are roots of equivalent computations, except that
+// parameters 2 * i and 2 * i + 1 are switched.
+bool IsSameComputationExceptParams(const HloInstruction* a,
+                                   const HloInstruction* b) {
+  if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) {
+    return false;
+  }
+  if (a->opcode() == HloOpcode::kParameter) {
+    // Check that parameters were switched.
+    return a->parameter_number() == (b->parameter_number() ^ 1);
+  }
+  // If the operation has no operands, it should actually be the same.
+  if (a->operand_count() == 0) {
+    return a == b;
+  }
+  // Otherwise recursively compare all operands.
+  for (int64 i = 0; i < a->operand_count(); ++i) {
+    if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Check that the comparison computation has been modified to add a tie breaker
+// using 'iota_parameter'.
+void CheckComputationHasTieBreaker(const HloInstruction* root,
+                                   int64 iota_parameter) {
+  // With the tie breaker, the root instruction should be
+  //   Select(Eq(Comp(), CompReverse()), Lt(), Comp())
+  // with Comp() being the original comparison function, and CompReverse() being
+  // the copied comparison function where the parameters are reversed. Lt() is
+  // the tie breaker comparison using the Iota operand.
+  ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
+  ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq);
+
+  // Check that the tie breaker instruction is correct.
+  EXPECT_THAT(root->operand(1),
+              GmockMatch(m::Lt(m::Parameter(iota_parameter * 2),
+                               m::Parameter(iota_parameter * 2 + 1))));
+  EXPECT_EQ(root->operand(2), root->operand(0)->operand(0));
+
+  // Check that Comp() and CompReverse() are equivalent except that
+  // CompReverse() has reversed parameters.
+  EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0),
+                                            root->operand(0)->operand(1)));
+}
+
+TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = s32[] parameter(2)
+     p.1.rhs = s32[] parameter(3)
+     ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = s32[64,8732]{1,0} iota(), iota_dimension=1
+     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=true
+     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota()), 0)));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
+}
+
+TEST_F(StableSortExpanderTest,
+       StabilizeSortReuseIotaOperandComplicatedComparison) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = s32[] parameter(2)
+     p.1.rhs = s32[] parameter(3)
+     max = u32[] constant(2147483647)
+     zero = s32[] constant(0)
+     lhs.signed = s32[] bitcast-convert(p.0.lhs)
+     lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
+     lhs.flipped = u32[] subtract(max, lhs.unsigned)
+     lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
+     lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero)
+     lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
+     rhs.signed = s32[] bitcast-convert(p.0.rhs)
+     rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
+     rhs.flipped = u32[] subtract(max, rhs.unsigned)
+     rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
+     rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero)
+     rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
+     ROOT lt = pred[] less-than(lhs.converted, rhs.converted)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = s32[64,8732]{1,0} iota(), iota_dimension=1
+     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=true
+     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota()), 0)));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
+}
+
+TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = s32[] parameter(2)
+     p.1.rhs = s32[] parameter(3)
+     ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = s32[64,8732]{1,0} parameter(1)
+     ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=true
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root, GmockMatch(m::Tuple(
+                m::GetTupleElement(
+                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0),
+                m::GetTupleElement(
+                    m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1))));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->operand(0)->to_apply()->root_instruction(),
+      /*iota_parameter=*/2);
+}
+
+TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = s32[] parameter(2)
+     p.1.rhs = s32[] parameter(3)
+     ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = s32[64,8732]{1,0} iota(), iota_dimension=1
+     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=false
+     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie());
+}
+
+TEST_F(StableSortExpanderTest,
+       StabilizeSortDontReuseIotaOperandWrongDimension) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = s32[] parameter(2)
+     p.1.rhs = s32[] parameter(3)
+     ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = s32[64,8732]{1,0} iota(), iota_dimension=0
+     sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=true
+     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  // Simplify away the "wrapper" tuple around the new sort.
+  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
+      [](const Shape&, const Shape&) { return false; }));
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->to_apply()->root_instruction(),
+      /*iota_parameter=*/2);
+}
+
+TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = f32[] parameter(0)
+     p.0.rhs = f32[] parameter(1)
+     p.1.lhs = f32[] parameter(2)
+     p.1.rhs = f32[] parameter(3)
+     ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = f32[64,8732]{1,0} parameter(0)
+     values = f32[64,8732]{1,0} iota(), iota_dimension=1
+     sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values),
+       dimensions={1}, to_apply=compare, is_stable=true
+     ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  // Simplify away the "wrapper" tuple around the new sort.
+  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
+      [](const Shape&, const Shape&) { return false; }));
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->to_apply()->root_instruction(),
+      /*iota_parameter=*/2);
+}
+
+TEST_F(StableSortExpanderTest, StabilizeSortR1) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = s32[] parameter(0)
+     p.0.rhs = s32[] parameter(1)
+     mask = s32[] constant(65535)
+     lhs = s32[] and(p.0.lhs, mask)
+     rhs = s32[] and(p.0.rhs, mask)
+     ROOT lt = pred[] less-than(lhs, rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = s32[64,8732]{1,0} parameter(0)
+     ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
+       is_stable=true
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota()), 0)));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
+}
+
+TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
+  const char* hlo_string = R"(
+   HloModule permutation_sort
+
+   compare {
+     p.0.lhs = s32[] parameter(0)
+     p.0.rhs = s32[] parameter(1)
+     mask = s32[] constant(65535)
+     lhs = s32[] and(p.0.lhs, mask)
+     rhs = s32[] and(p.0.rhs, mask)
+     ROOT lt = pred[] less-than(lhs, rhs)
+   }
+
+   ENTRY sort_computation {
+     keys = s32[64,8732]{1,0} parameter(0)
+     sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
+       is_stable=true
+     ROOT neg = s32[64,8732]{1,0} negate(sort)
+   })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  StableSortExpander stabilizer;
+  EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement(
+                        m::Sort(m::Parameter(0), m::Iota()), 0))));
+  CheckComputationHasTieBreaker(
+      root->operand(0)->operand(0)->to_apply()->root_instruction(),
+      /*iota_parameter=*/1);
+}
+
+}  // namespace
+}  // namespace xla
@@ -1072,7 +1072,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
   auto keys = builder.AddInstruction(
       HloInstruction::CreateParameter(0, keys_shape, "keys"));
   TF_ASSERT_OK_AND_ASSIGN(
-      auto* sort, MakeSortHlo(keys_shape, {keys}, 0, &builder, module_.get()));
+      auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false,
+                              &builder, module_.get()));
 
   computation_ = module_->AddEntryComputation(builder.Build());
   RunAnalysis();
@@ -1094,7 +1095,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto* sort,
       MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
-                  {keys, values}, 0, &builder, module_.get()));
+                  {keys, values}, 0, /*is_stable=*/false, &builder,
+                  module_.get()));
 
   computation_ = module_->AddEntryComputation(builder.Build());
   RunAnalysis();
@@ -1146,7 +1146,7 @@ xla_test(
 xla_test(
     name = "reduce_test",
     srcs = ["reduce_test.cc"],
-    shard_count = 40,
+    shard_count = 31,
    tags = [
         "optonly",
     ],
@@ -1188,6 +1188,8 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
       p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
       p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
       p{v{6}, v{6, 7}, "b,bc->c"},
+      p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
+      p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
       p{v{77}, v{77}, "a,a->a"},
       p{v{77}, v{77, 55}, "a,ab->ba"},
       p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
@@ -1265,5 +1267,51 @@ ENTRY %test {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
 }
 
+XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
+  // Tests for a caching bug in the XLA CPU backend.
+  absl::string_view hlo_string =
+      R"(
+HloModule CpuTiledDotEmitterCachingBug
+
+ENTRY main {
+  lhs = f32[20,40] parameter(0)
+  rhs_0 = f32[40,1] parameter(2)
+  rhs_1 = f32[1,40] parameter(1)
+
+  dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+
+  ROOT result = f32[20,1] divide(dot_0, dot_1)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
+}
+
+XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
+  // Tests for a caching bug in the XLA CPU backend.
+  absl::string_view hlo_string =
+      R"(
+HloModule CpuTiledDotEmitterCachingBug
+
+ENTRY main {
+  lhs_0 = f32[20,40] parameter(0)
+  rhs_0 = f32[40,1] parameter(1)
+  lhs_1 = f32[1,40] parameter(2)
+  rhs_1 = f32[20,40] parameter(3)
+
+  dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+
+  dot_0_reshaped = f32[20] reshape(dot_0)
+  dot_1_reshaped = f32[20] reshape(dot_1)
+
+  ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
+}
+
 }  // namespace
 }  // namespace xla
@@ -205,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
 }
 
+StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
+    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
+    int64 num_replicas) {
+  HloRunner::ReplicatedExecuteOptions options;
+  options.num_replicas = num_replicas;
+  for (auto argument : arguments) {
+    options.arguments.push_back(argument);
+  }
+  return test_runner_.ExecuteReplicated(std::move(module), options);
+}
+
 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
     const HloModule& test_module,
     const std::function<void(HloModule*)>& reference_preprocessor) {
@@ -173,6 +173,11 @@ class HloTestBase : public ::testing::Test {
   Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                              absl::Span<Literal* const> arguments);
 
+  // Executes the given module on multiple replicas.
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
+      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
+      int64 num_replicas);
+
   // Executes the given hlo module on two backends and compares results.
   //
   // 'arguments': the input of the hlo module.
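A hypothetical call from a test body, matching the added declaration (the module text and literal are illustrative):

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Literal input = LiteralUtil::CreateR0<float>(1.0f);
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input}, /*num_replicas=*/2));
  EXPECT_EQ(results.size(), 2);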
@@ -30,6 +30,11 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly =
         **kwargs
     )
 
+def xla_py_proto_library(**kwargs):
+    # Note: we don't currently define a proto library target for Python in OSS.
+    _ignore = kwargs
+    pass
+
 def xla_py_grpc_library(**kwargs):
     # Note: we don't currently define any special targets for Python GRPC in OSS.
     _ignore = kwargs
|
@@ -122,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
                             .HostMemory("literal"),
                         XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
 
+REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
+                            .Device(DEVICE_XLA_GPU)
+                            .HostMemory("handles")
+                            .HostMemory("tensors"),
+                        XRTReadToTensorOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
+                            .Device(DEVICE_XLA_CPU)
+                            .HostMemory("handles")
+                            .HostMemory("tensors"),
+                        XRTReadToTensorOp<XRTGenericDeviceAccessor>);
+
 REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
                             .Device(DEVICE_XLA_GPU)
                             .HostMemory("handle"),
@@ -25,6 +25,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -40,6 +41,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -215,27 +217,29 @@ class XRTAllocateFromTensorOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
+    std::vector<int64> minor_to_major;
     if (ctx->HasAttr("layouts")) {
-      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_));
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
     }
     OP_REQUIRES(
         ctx, tf_shapes_.size() == dtypes_.size(),
         errors::InvalidArgument("shapes and dtypes must be the same length"));
     std::vector<xla::Shape> xla_shapes;
+    xla_shapes.reserve(tf_shapes_.size());
     for (int i = 0; i < tf_shapes_.size(); i++) {
       xla::Shape xla_shape;
       OP_REQUIRES_OK(
           ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
-      xla_shapes.push_back(xla_shape);
+      xla_shapes.push_back(std::move(xla_shape));
     }
     if (xla_shapes.size() > 1 || make_tuple) {
       shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
     } else {
       shape_.Swap(&xla_shapes.front());
     }
-    if (!minor_to_major_.empty()) {
+    if (!minor_to_major.empty()) {
       xla::Shape shape_with_layouts;
-      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_,
+      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                              /*layout_func=*/nullptr,
                                              &shape_with_layouts));
       shape_.Swap(&shape_with_layouts);
@@ -304,7 +308,6 @@ class XRTAllocateFromTensorOp : public OpKernel {
  private:
   std::vector<TensorShape> tf_shapes_;
   DataTypeVector dtypes_;
-  std::vector<int64> minor_to_major_;
   xla::Shape shape_;
 };
 
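As an aside on the "layouts" attribute these two hunks touch (now a constructor-local rather than a member): a minor-to-major list names dimensions from fastest- to slowest-varying. A small illustrative sketch using the standard XLA helper; the shapes themselves are made up, not from the commit.

// Illustrative only: minor_to_major {1, 0} stores f32[20,40] row-major
// (dimension 1 contiguous in memory); {0, 1} stores it column-major.
xla::Shape row_major = xla::ShapeUtil::MakeShapeWithLayout(
    xla::F32, /*dimensions=*/{20, 40}, /*minor_to_major=*/{1, 0});
xla::Shape col_major = xla::ShapeUtil::MakeShapeWithLayout(
    xla::F32, /*dimensions=*/{20, 40}, /*minor_to_major=*/{0, 1});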
@@ -487,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel {
     OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                             ctx, allocation->device_ordinal(), &device_ref));
 
-    xla::Literal literal;
+    xla::Literal literal(allocation->on_host_shape());
     OP_REQUIRES_OK(
         ctx, allocation->ToLiteral(device_ref.backend(),
                                    device_ref.device_ordinal(), &literal));
@@ -499,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel {
   }
 };
 
+// Op that reads a device-resident tuple to host memory and returns it as
+// zero or more tensors.
+template <class DeviceAccessor>
+class XRTReadToTensorOp : public OpKernel {
+ public:
+  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+  }
+  ~XRTReadToTensorOp() override = default;
+  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
+  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "XRTReadToTensorOp::Compute";
+
+    const Tensor& handle_tensor = ctx->input(0);
+    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
+    // just scalars).
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+        errors::Internal("computation input should be an int64 scalar"));
+    int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+    ResourceMgr* rm;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    XRTTupleAllocation* allocation;
+    OP_REQUIRES_OK(
+        ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+    core::ScopedUnref allocation_unref(allocation);
+
+    if (discard_) {
+      VLOG(2) << "Releasing handle " << allocation_handle;
+      OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+                              rm, allocation_handle));
+    }
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us, while the ScopedRef is live.
+    class DeviceAccessor::ScopedRef device_ref;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+                            ctx, allocation->device_ordinal(), &device_ref));
+
+    xla::Shape shape = allocation->on_host_shape();
+    int output = 0;
+    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
+        &shape,
+        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
+          if (subshape->IsTuple()) return Status::OK();
+
+          xla::PrimitiveType xla_type;
+          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
+              ctx->expected_output_dtype(output), &xla_type));
+          if (xla_type != subshape->element_type()) {
+            return errors::InvalidArgument(
+                "Type mismatch between buffer type (", subshape->ToString(),
+                ") and tensor type (",
+                DataTypeString(ctx->expected_output_dtype(output)),
+                ") for output tensor ", output);
+          }
+
+          TensorShape output_shape;
+          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));
+
+          Tensor* output_tensor;
+          TF_RETURN_IF_ERROR(
+              ctx->allocate_output(output, output_shape, &output_tensor));
+
+          XRTTupleAllocation* sub;
+          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+              allocation, index, &sub, /*alias_parent_allocation=*/true));
+          core::ScopedUnref sub_unref(sub);
+
+          xla::MutableBorrowingLiteral literal;
+          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
+              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
+              &literal));
+          TF_RETURN_IF_ERROR(sub->ToLiteral(
+              device_ref.backend(), device_ref.device_ordinal(), &literal));
+
+          ++output;
+          return Status::OK();
+        });
+    OP_REQUIRES_OK(ctx, status);
+  }
+
+ private:
+  bool discard_;
+  DataTypeVector dtypes_;
+};
+
 // Op that writes a new literal value into device-resident memory.
 template <class DeviceAccessor>
 class XRTWriteLiteralOp : public OpKernel {
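To make the traversal order concrete, here is a small self-contained sketch (not from the commit) showing how ShapeUtil walks a tuple shape and visits leaf buffers in the same order the op above assigns output tensors. The tuple shape is made up for illustration.

#include <iostream>

#include "tensorflow/compiler/xla/shape_util.h"

int main() {
  // A tuple (f32[2], (s32[], f32[3])), purely for illustration.
  xla::Shape tuple = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {2}),
       xla::ShapeUtil::MakeTupleShape(
           {xla::ShapeUtil::MakeShape(xla::S32, {}),
            xla::ShapeUtil::MakeShape(xla::F32, {3})})});
  int output = 0;
  xla::ShapeUtil::ForEachSubshape(
      tuple, [&](const xla::Shape& subshape, const xla::ShapeIndex& index) {
        if (subshape.IsTuple()) return;  // only leaf buffers become tensors
        // Visits {0}, then {1,0}, then {1,1}: an in-order walk of the tree.
        std::cout << "output " << output++ << ": " << subshape.ToString()
                  << " at index " << index.ToString() << "\n";
      });
  return 0;
}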
@@ -151,6 +151,27 @@ releases the handle.
 'literal' is a serialized xla::LiteralProto proto.
 )");
 
+REGISTER_OP("XRTReadToTensor")
+    .Input("handles: int64")
+    .Attr("release_handles: bool = false")
+    .Attr("dtypes: list(type)")
+    .Output("tensors: dtypes")
+    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
+    .Doc(
+        R"(
+Copies allocated values from device memory and returns them as zero or more
+Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned.
+In general, the tensors returned for a handle correspond to an in-order
+traversal of the tuple-tree value referenced by the handle.
+
+'handles' contains ids returned from Ops that produced on-device allocations.
+At present, only a single (scalar) handle is supported.
+'dtypes' are the expected types for each `Tensor` to be returned. If the
+expected and actual tensor types do not match, an error is returned.
+'release_handles': if True, `handles` are released.
+'tensors' are the output Tensors.
+)");
+
 REGISTER_OP("XRTReleaseAllocationHandle")
     .Input("handle: int64")
     .SetShapeFn(tensorflow::shape_inference::NoOutputs)
@@ -220,7 +220,7 @@ XRTTupleAllocation::~XRTTupleAllocation() {
 }
 
 Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
-                                     xla::Literal* literal) {
+                                     xla::MutableLiteralBase* literal) {
   auto transfer_manager = backend->transfer_manager();
   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
@@ -234,9 +234,8 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
                            " has been released");
     }
   }
-  TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
-                                    stream.get(), shaped_buffer));
-  return Status::OK();
+  return transfer_manager->TransferLiteralFromDevice(stream.get(),
+                                                     shaped_buffer, *literal);
 }
 
 Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
@@ -147,7 +147,7 @@ class XRTTupleAllocation : public ResourceBase {
 
   // Copies the allocation from device to host and returns it in literal.
   Status ToLiteral(xla::Backend* backend, int device_ordinal,
-                   xla::Literal* literal);
+                   xla::MutableLiteralBase* literal);
 
   // Write a new literal value to the allocation.
   Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
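The point of widening ToLiteral from xla::Literal* to xla::MutableLiteralBase* is that callers can now pass a literal that aliases memory they already own, so the device-to-host transfer writes in place instead of into a freshly allocated literal. A minimal sketch under that assumption; the helper function and its name are illustrative, not part of the commit.

// Hypothetical helper: read an allocation straight into a pre-allocated
// TensorFlow host tensor, skipping the intermediate owning xla::Literal
// that the old xla::Literal* signature forced callers to create.
Status ReadIntoTensor(XRTTupleAllocation* allocation, xla::Backend* backend,
                      int device_ordinal, Tensor* out) {
  xla::MutableBorrowingLiteral borrowed;
  // 'borrowed' aliases out's buffer; no extra host-side copy is made.
  TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(allocation->on_host_shape()), out,
      &borrowed));
  return allocation->ToLiteral(backend, device_ordinal, &borrowed);
}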
@@ -218,7 +218,6 @@ cc_library(
         "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
         "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
         "//tensorflow/contrib/text:all_ops",
-        "//tensorflow/contrib/tpu:all_ops",
     ] + select({
         "//tensorflow:android": [],
         "//tensorflow:ios": [],
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user