Merge branch 'master' into upstream-staging-norocm-tag-1

commit a5eb965967
Gunhan Gulsoy, 2019-02-14 10:15:04 -08:00, committed by GitHub
425 changed files with 34994 additions and 20327 deletions


@@ -26,14 +26,28 @@ import sys as _sys
# API IMPORTS PLACEHOLDER
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# pylint: disable=g-bad-import-order
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg="Limited tf.summary API due to missing TensorBoard installation")
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v2.estimator'))
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
@@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v2.keras'))
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
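
The __path__ block added at the top of this hunk makes the directory containing the generated API submodules importable, so that "from tensorflow.foo import bar" works. A small function-style restatement of that logic, purely illustrative (the template runs it at module level, and any generated submodule can stand in for bitwise):

import os
import sys

def ensure_api_dir_on_path(api_module, package_name):
    """Mirror of the __path__ handling in the template above (illustrative).

    api_module stands in for the generated `bitwise` API module; any
    generated submodule would do, there is nothing special about bitwise.
    """
    current_module = sys.modules[package_name]
    api_dir = os.path.dirname(os.path.dirname(api_module.__file__))
    if not hasattr(current_module, '__path__'):
        # Regular modules have no __path__; adding one makes submodule
        # lookups search api_dir.
        current_module.__path__ = [api_dir]
    elif api_dir not in current_module.__path__:
        # Already a package: just make sure the generated-API dir is searched.
        current_module.__path__.append(api_dir)
    return api_dir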


@@ -70,7 +70,7 @@ _API_MODULE = app # pylint: disable=undefined-variable
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:


@@ -150,6 +150,7 @@ cc_library_with_android_deps(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
],
)
@@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc(
pkg = "//tensorflow/core",
)
tf_gen_op_wrappers_cc(
name = "tpu_ops",
include_internal_ops = 1,
op_lib_names = [
"tpu_configuration_ops",
"tpu_cross_replica_ops",
"tpu_embedding_ops",
"tpu_functional_ops",
"tpu_heartbeat_ops",
"tpu_host_compute_ops",
"tpu_infeed_ops",
"tpu_outfeed_ops",
"tpu_ordinal_selector_ops",
"tpu_replication_ops",
],
pkg = "//tensorflow/core",
visibility = ["//tensorflow:internal"],
)
cc_library_with_android_deps(
name = "cc_op_gen_main",
srcs = [


@@ -81,6 +81,7 @@ cc_library(
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
]) + if_android([


@@ -22,11 +22,16 @@ import os as _os
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg=(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation"))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(


@@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute")
.Attr("shapes: list(shape) >= 0")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaSendFromHost")
.Input("inputs: Tinputs")
.Input("dynamic_key: string")
.Attr("Tinputs: list(type) >= 0")
.Attr("key: string")
.Attr("device_ordinal: int")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaRecvAtHost")
.Input("dynamic_key: string")
.Output("outputs: Toutputs")
.Attr("Toutputs: list(type) >= 0")
.Attr("key: string")
.Attr("device_ordinal: int")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("InputTest") REGISTER_OP("InputTest")
.Output("o: float") .Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {


@@ -200,7 +200,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
auto serialized = absl::make_unique<char[]>(size);
TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size));
uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
VLOG(1) << "Subgraph fingerprint:" << fingerprint;
call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
return Status::OK();
}


@@ -250,6 +250,29 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "self_adjoint_eig_op_test",
size = "medium",
srcs = ["self_adjoint_eig_op_test.py"],
# TODO(kuny): remove it after b/124377352 is fixed.
disabled_backends = [
"cpu",
"gpu",
"cpu_ondemand",
],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:map_fn",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"@absl_py//absl/testing:parameterized",
],
)
tf_xla_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",


@@ -0,0 +1,62 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.self_adjoint_eig."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test
class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase):
def _test(self, dtype, shape):
np.random.seed(1)
x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
x_np = x_np + np.swapaxes(x_np, -1, -2)
n = shape[-1]
e_np, _ = np.linalg.eigh(x_np)
with self.cached_session() as sess:
x_tf = array_ops.placeholder(dtype)
with self.test_scope():
e, v = linalg_ops.self_adjoint_eig(x_tf)
e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np})
v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n)
self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6)
self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6)
SIZES = [1, 2, 5, 10, 32]
DTYPES = [np.float32]
PARAMS = itertools.product(SIZES, DTYPES)
@parameterized.parameters(*PARAMS)
def testSelfAdjointEig(self, n, dtype):
for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
self._test(dtype, batch_dims + (n, n))
if __name__ == "__main__":
test.main()


@@ -102,7 +102,7 @@ class ListOpsTest(xla_test.XLATestCase):
_, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Set the max number of elements"):
self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15)))
self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
def testEmptyTensorListMax(self):
with self.cached_session() as sess, self.test_scope():
@@ -136,6 +136,17 @@ class ListOpsTest(xla_test.XLATestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [3.0, 2.0])
def testSetDoesNotUpdatePushIndex(self):
with self.cached_session(), self.test_scope():
l = list_ops.empty_tensor_list(
element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
# SetItem should not change the push index.
l = list_ops.tensor_list_set_item(l, 1, 3.)
l = list_ops.tensor_list_push_back(l, 5.)
l = list_ops.tensor_list_push_back(l, 7.)
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [5., 7.])
def testGetSetReserved(self):
with self.cached_session(), self.test_scope():
l = list_ops.tensor_list_reserve(
@@ -146,6 +157,25 @@ class ListOpsTest(xla_test.XLATestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [3.0, 0.0])
def testSetStackReservedUnknownElementShape(self):
with self.cached_session(), self.test_scope():
l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_shape=None, num_elements=2)
l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0])
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]])
def testPushInEmptyListWithUnknownElementShape(self):
with self.cached_session(), self.test_scope():
l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
# Pushing an element with a different shape should raise an error.
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"):
l = list_ops.tensor_list_push_back(l, 5.)
self.evaluate(
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
def testGetSetReservedNonScalar(self):
with self.cached_session() as sess, self.test_scope():
l = list_ops.tensor_list_reserve(


@@ -72,6 +72,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
output = op(pinp)
result = session.run(output, {pinp: inp})
if equality_test is None:
self.assertEqual(output.dtype, expected.dtype)
self.assertAllCloseAccordingToType(
result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
else:
@@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
dtype=dtype)).astype(dtype),
rtol=1e-4,
atol=1e-6)
@@ -710,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
expected=np.array([[2, 1]], dtype=dtype))
expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
self._assertOpOutputMatchesExpected(
math_ops.negative,
@@ -880,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int32(3))
def testSizeWithInt64OutType(self):
def size_op(x):
return array_ops.size_internal(x, optimize=False, out_type=np.int64)
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
size_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int64(3))
def testUnpack(self):
self._assertOpOutputMatchesExpected(
array_ops.unstack,
@@ -989,7 +1002,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
def _assertSoftplusMatchesExpected(self, features, dtype):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features)
expected = np.logaddexp(zero, features).astype(dtype)
self._assertOpOutputMatchesExpected(
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)


@@ -171,13 +171,11 @@ tf_cuda_library(
name = "trt_resources",
srcs = [
"utils/trt_int8_calibrator.cc",
"utils/trt_resource_manager.cc",
"utils/trt_resources.cc",
],
hdrs = [
"utils/trt_int8_calibrator.h",
"utils/trt_lru_cache.h",
"utils/trt_resource_manager.h",
"utils/trt_resources.h",
],
deps = [
@@ -266,7 +264,6 @@ tf_cuda_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
@@ -362,11 +359,12 @@ cc_library(
],
)
tf_cc_test(
tf_cuda_cc_test(
name = "segment_test",
size = "small",
srcs = ["segment/segment_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
@@ -432,7 +430,7 @@ cc_library(
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
],
)
@@ -441,7 +439,7 @@ cc_library(
srcs = ["utils/test_utils.cc"],
hdrs = ["utils/test_utils.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"@com_googlesource_code_re2//:re2",
],
)


@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
@@ -106,6 +105,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
"ExpandDims",
"FusedBatchNorm",
"FusedBatchNormV2",
"GatherV2",
"Identity",
"LeakyRelu",
"Log",
@@ -190,55 +190,6 @@ tensorflow::Status BuildNodeMap(
} // namespace
// Function to get calibration from ResourceMgr and put them into nodedef.
tensorflow::Status ConvertCalibGraphToInferGraph(
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph,
bool is_dyn_op) {
LOG(INFO) << "Starting Calib Conversion";
*infer_graph = graph_def;
auto trt_rm = TRTResourceManager::instance();
auto calib_rm = trt_rm->getManager("TRTCalibration");
int num_nodes = infer_graph->node_size();
if (!is_dyn_op) {
LOG(WARNING) << "Construction of static int8 engine is not implemented "
"yet!. Dynamic engine will be constructed";
}
for (int i = 0; i < num_nodes; ++i) {
auto n = infer_graph->mutable_node(i);
if (n->op() == "TRTEngineOp") {
VLOG(1) << "Processing " << n->name();
const string& container_name = n->attr().at("segment_funcdef_name").s();
TRTCalibrationResource* cres = nullptr;
auto status = calib_rm->Lookup(container_name, "Calibrator", &cres);
if (!status.ok()) {
LOG(ERROR) << "Could not get Calibration information. Did you run with "
"calibration data?";
return tensorflow::errors::FailedPrecondition(
"Need to run graph with calibration data first!");
}
tensorflow::core::ScopedUnref calib_sc(cres);
if (cres->calibrator_) {
cres->calibrator_->waitAndSetDone();
cres->thr_->join();
const auto& calibration_table =
cres->calibrator_->getCalibrationTableAsString();
if (calibration_table.empty()) {
LOG(ERROR) << "Calibration table is empty";
return tensorflow::errors::Unknown(
"Calibration table is missing. This shouldn't have happened!");
}
n->mutable_attr()->at("calibration_data").set_s(calibration_table);
} else {
LOG(ERROR) << "Can't get TRTCalibrator from resource manager!";
return tensorflow::errors::Unknown(
"Can't get TRTCalibrator from resource manager!");
}
TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name));
}
}
return tensorflow::Status::OK();
}
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,


@@ -85,12 +85,6 @@ struct ConversionParams {
std::vector<int> cached_engine_batches; // list of cached engines
};
// This method extracts calibration information from the resource managers
// and puts them in to engine nodedefs.
tensorflow::Status ConvertCalibGraphToInferGraph(
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
bool is_dyn_op);
// - max_batch_size: maximum batch size which can be used for inference for
// optimization targets inference run with max batch size.
// - max_workspace_size_bytes: The upper bound of memory allowance for engine


@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
@@ -379,6 +378,32 @@ tensorflow::Status CreateBroadcastableScalarConstant(
return Status::OK();
}
// Convert an axis from TF format to TRT format while validating. TF format
// includes the batch dimension, while TRT does not. TF can also use negative
// indices.
// TODO(tmorris): Use this method in more ops.
tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims,
absl::string_view node_name, int* trt_axis) {
const int tf_nb_dims = trt_nb_dims + 1;
// Check bounds.
if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
return tensorflow::errors::InvalidArgument(
"Axis value of ", tf_axis, " is out of bounds, must be in range [",
-tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name);
}
// Make negative axis positive.
if (tf_axis < 0) tf_axis += tf_nb_dims;
// Don't allow axis to be the batch dimension.
if (tf_axis == 0) {
return tensorflow::errors::Unimplemented(
"TensorRT does not allow manipulation of the batch dimension, at ",
node_name);
}
// Remove batch dimension.
*trt_axis = tf_axis - 1;
return Status::OK();
}
inline bool DimsEqual(const nvinfer1::Dims& dim_l,
const nvinfer1::Dims& dim_r) {
if (dim_l.nbDims != dim_r.nbDims) {
@@ -3413,6 +3438,29 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) {
return tensorflow::Status::OK();
}
tensorflow::Status ConvertGather(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
if (axis.size() != 1) {
return tensorflow::errors::InvalidArgument(
"Axis for GatherV2 must be a scalar, at ", node_def.name());
}
int trt_axis = 0;
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
node_def.name(), &trt_axis));
if (params->validation_only) return Status::OK();
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
*const_cast<nvinfer1::ITensor*>(inputs.at(1).tensor()), trt_axis);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
return Status::OK();
}
tensorflow::Status ConvertMatMulHelper(OpConverterParams* params,
TRT_TensorOrWeights tensor_input,
TRT_ShapedWeights weights_raw,
@@ -3643,6 +3691,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["GatherV2"] = ConvertGather;
(*registration)["LeakyRelu"] = ConvertLeakyRelu;
(*registration)["MatMul"] = ConvertMatMul;
(*registration)["Pad"] = ConvertPad;


@@ -190,6 +190,11 @@ class TRT_ShapedWeights {
string DebugString() const;
template <typename T>
absl::Span<const T> GetSpan() const {
return absl::Span<const T>(tensor_.flat<T>().data(), count());
}
// TODO(aaroey): make these private.
nvinfer1::Dims shape_; // Note: shape.type[] is not used.
tensorflow::DataType type_;


@@ -3129,6 +3129,126 @@ TEST_F(OpConverterTest, ConvertTopK) {
}
}
template <DataType dtype>
void TestConvertGather(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
// Get the NodeDef for GatherV2.
Scope s = Scope::NewRootScope();
auto params = ops::Placeholder(s.WithOpName("params"), dtype);
auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
const NodeDef& node_def = gather.operation.node()->def();
struct TestParams {
std::vector<int> params_dims;
std::vector<int> indices_dims;
std::vector<int> indices;
int axis;
std::vector<int> expected_output_dims;
std::vector<int> expected_output;
};
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
const int kGatherOKCases = 5;
TestParams ok_params[kGatherOKCases] = {
// Vector indices (output is rank(params)).
TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}},
TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}},
TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}},
TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}},
// Higher rank indices (output is rank(params) + rank(indices) - 1).
TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}},
};
// Ok.
for (int i = 0; i < kGatherOKCases; i++) {
test->Reset();
test->AddTestTensor("params", ok_params[i].params_dims, 1,
TfDataTypeToTrt(dtype));
test->AddTestTensor("indices", ok_params[i].indices_dims, 1,
nvinfer1::DataType::kINT32);
test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
// Create input in CType and convert expected output to CType.
std::vector<CType> inputs = {CType(1), CType(2), CType(3),
CType(4), CType(5), CType(6)};
std::vector<CType> converted_expected_output(
ok_params[i].expected_output.begin(),
ok_params[i].expected_output.end());
const DataVec input_data{
{"params", test::AsTensor<CType>(inputs)},
{"indices", test::AsTensor<int32>(ok_params[i].indices)}};
DataVec output_data{
{"my_gather",
ConstructTensor<CType>(ok_params[i].expected_output.size())}};
test->BuildAndRun(input_data, &output_data);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(converted_expected_output));
}
}
TEST_F(OpConverterTest, ConvertGather) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"GatherV2 got 0 inputs but expected 3, at my_gather");
}
// Get the NodeDef for GatherV2.
Scope s = Scope::NewRootScope();
auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT);
auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
const NodeDef& node_def = gather.operation.node()->def();
{
// Axis is a tensor, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestTensor("axis", {1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"The input \"axis\" for GatherV2 must be a constant, at my_gather");
}
{
// Axis is out of bounds, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestWeights<int32>("axis", {1}, {4});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Axis value of 4 is out of bounds, must be in "
"range [-4, 4), at my_gather");
}
{
// Axis is batch dimension, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestWeights<int32>("axis", {1}, {0});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the "
"batch dimension, at my_gather");
}
Reset();
TestConvertGather<DT_FLOAT>(this);
TestConvertGather<DT_HALF>(this);
TestConvertGather<DT_INT32>(this);
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
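
The expected shapes in TestConvertGather follow TF's GatherV2 rule: output rank equals rank(params) + rank(indices) - 1. One of the test cases above can be cross-checked with plain NumPy (values only; the implicit batch dimension handling is TRT-specific):

import numpy as np

params = np.array([[[1, 2, 3], [4, 5, 6]]])  # the {1, 2, 3}-shaped input used by the tests
indices = np.array([2, 0, 1])
out = np.take(params, indices, axis=-1)      # gather along the last axis
print(out.shape)    # (1, 2, 3)
print(out.ravel())  # [3 1 2 6 4 5], matching expected_output in the test above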


@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -295,27 +294,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
return this->AllocateCalibrationResources(ctx, cr);
}}));
tensorflow::core::ScopedUnref calib_sc(calib_res);
// TODO(aaroey): here we also add the resource to the ResourceMgr singleton.
// This is needed before we migrate all uses of calib_graph_to_infer_graph()
// to the new calibration workflow. After that we'll remove this block.
{
auto deprecated_rm =
TRTResourceManager::instance()->getManager("TRTCalibration");
TRTCalibrationResource* copied_resource = nullptr;
// Check whether the resource exists, and create it if not.
if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource)
.ok()) {
// Do nothing if the resource exists.
copied_resource->Unref();
} else {
copied_resource = calib_res;
// Increase the refcount by 1 then transfer the ownership of that refcount
// to the ResourceMgr singleton.
copied_resource->Ref();
OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator",
copied_resource));
}
}
int num_inputs = ctx->num_inputs();
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;


@@ -38,16 +38,23 @@ def load_trt_ops():
if _trt_ops_so:
return
try:
# pylint: disable=g-import-not-at-top,unused-variable
# This registers the TRT ops, it doesn't require loading TRT library.
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
# pylint: enable=g-import-not-at-top,unused-variable
except ImportError as e:
print("**** Failed to import TF-TRT ops. This is because the binary was "
"not built with CUDA or TensorRT enabled. ****")
raise e
# TODO(laigd): we should load TF-TRT kernels here as well after removing the
# swig binding.
try:
# TODO(lagid): It is not known why these unused imports were introduced.
# pylint: disable=g-import-not-at-top
# Investigate and get rid of these, if not required.
# pylint: disable=unused-import,g-import-not-at-top,unused-variable
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import,g-import-not-at-top,unused-variable
# pylint: enable=g-import-not-at-top
_trt_ops_so = load_library.load_op_library(
resource_loader.get_path_to_datafile("_trt_ops.so"))


@@ -19,7 +19,9 @@ limitations under the License.
#include <vector>
#include "re2/re2.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tensorrt {


@@ -16,8 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tensorrt {


@@ -1,45 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace tensorrt {
std::shared_ptr<TRTResourceManager>
tensorflow::tensorrt::TRTResourceManager::instance() {
static std::shared_ptr<TRTResourceManager> instance_(new TRTResourceManager);
return instance_;
}
std::shared_ptr<tensorflow::ResourceMgr>
tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) {
// mutex is held for lookup only. Most instantiations where mutex will be held
// longer will be during op creation and should be ok.
tensorflow::mutex_lock lock(map_mutex_);
auto s = managers_.find(op_name);
if (s == managers_.end()) {
auto it = managers_.emplace(
op_name, std::make_shared<tensorflow::ResourceMgr>(op_name));
VLOG(1) << "Returning a new manager " << op_name;
return it.first->second;
}
VLOG(1) << "Returning old manager " << op_name;
return s->second;
}
} // namespace tensorrt
} // namespace tensorflow


@@ -1,45 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorrt {
class TRTResourceManager {
TRTResourceManager() = default;
public:
static std::shared_ptr<TRTResourceManager> instance();
// returns a manager for given op, if it doesn't exists it creates one
std::shared_ptr<tensorflow::ResourceMgr> getManager(const string& op_name);
private:
std::unordered_map<string, std::shared_ptr<tensorflow::ResourceMgr>>
managers_;
tensorflow::mutex map_mutex_;
};
} // namespace tensorrt
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_


@@ -24,7 +24,7 @@ package(
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library")
cc_library(
name = "tf2xla_supported_ops_lib",
@@ -60,6 +60,14 @@ xla_proto_library(
],
)
xla_py_proto_library(
name = "tf2xla_py",
has_services = False,
api_version = 2,
visibility = ["//visibility:public"],
deps = [":tf2xla_proto"],
)
xla_proto_library(
name = "host_compute_metadata_proto",
srcs = ["host_compute_metadata.proto"],
@@ -283,6 +291,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)


@@ -107,11 +107,13 @@ tf_kernel_library(
"xla_pad_op.cc",
"xla_reduce_op.cc",
"xla_select_and_scatter_op.cc",
"xla_self_adjoint_eig_op.cc",
],
hdrs = [
"index_ops.h",
"shape_util.h",
],
tags = ["optonly"],
deps = [
":conv_op_helpers",
":if_op",
@@ -143,6 +145,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:quantize",
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",


@@ -104,7 +104,7 @@ class SizeOp : public XlaOpKernel {
for (int64 i = 0; i < rank; ++i) {
size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
}
size = xla::ConvertElementType(size, xla::S32);
size = xla::ConvertElementType(size, ctx->output_xla_type(0));
ctx->SetOutput(0, size);
}
};


@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -35,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
@@ -69,6 +71,43 @@ class TensorListLengthOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp);
// Creates an empty list with size (leading_dim, *element_shape) if
// element_shape is known at compile time. Otherwise creates one with size
// (leading_dim, 0) which gets initialized later in `GetInitializedList`.
Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index,
int64 leading_dim, DataType dtype, xla::XlaOp* list) {
TensorShape list_shape;
list_shape.AddDim(leading_dim);
xla::XlaOp element_shape_handle = ctx->Input(element_shape_index);
TF_ASSIGN_OR_RETURN(
bool is_element_shape_compile_time_const,
element_shape_handle.builder()->IsConstant(element_shape_handle));
PartialTensorShape partial_element_shape;
if (is_element_shape_compile_time_const) {
TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(
element_shape_index, &partial_element_shape));
}
if (is_element_shape_compile_time_const &&
partial_element_shape.IsFullyDefined()) {
TensorShape element_shape;
partial_element_shape.AsTensorShape(&element_shape);
list_shape.AppendShape(element_shape);
} else {
// If element_shape is not a compile time constant or if it is not fully
// defined we will have to wait for the first write call to fully allocate
// the array.
// TODO(srbs): We are using element_shape of [0] as a proxy to denote an
// uninitialized list. A better implementation may be to represent the
// list as a 3-tuple containining an explicit "initialized" flag. However,
// we would still need to create a dummy tensor for the first tuple
// element.
list_shape.AddDim(0);
}
*list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
list_shape.dim_sizes());
return Status::OK();
}
class TensorListReserveOp : public XlaOpKernel {
public:
explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -76,20 +115,15 @@ class TensorListReserveOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape element_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
int64 num_elements;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
TensorShape tensor_shape;
tensor_shape.AddDim(num_elements);
tensor_shape.AppendShape(element_shape);
xla::XlaOp list;
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list));
xla::XlaBuilder* b = ctx->builder();
ctx->SetTensorListOutput(
0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
tensor_shape.dim_sizes()),
xla::ConstantR0<int32>(b, num_elements)}));
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, num_elements)}));
}
private:
@@ -110,8 +144,6 @@ class EmptyTensorListOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape element_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
int64 max_num_elements;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
OP_REQUIRES(
@@ -119,15 +151,13 @@ class EmptyTensorListOp : public XlaOpKernel {
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
"size. Set the max number of elements."));
TensorShape tensor_shape;
tensor_shape.AddDim(max_num_elements);
tensor_shape.AppendShape(element_shape);
xla::XlaOp list;
OP_REQUIRES_OK(ctx,
CreateZerosList(ctx, 0, max_num_elements, dtype_, &list));
xla::XlaBuilder* b = ctx->builder();
ctx->SetTensorListOutput(
0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
tensor_shape.dim_sizes()),
xla::ConstantR0<int32>(b, 0)}));
0, xla::Tuple(b, {list, xla::ConstantR0<int32>(b, 0)}));
}
private:
@@ -274,6 +304,36 @@ REGISTER_XLA_OP(
Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
TensorListFromTensorOp);
// Returns the 0'th element of `tuple` containing the list tensor if it has been
// initialized already else creates one lazily. This allows lazy initialization
// of the list on the first call to SetItem or PushBack.
Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple,
const TensorShape& element_shape, DataType dtype,
xla::XlaOp* list) {
*list = xla::GetTupleElement(tuple, 0);
TensorShape list_shape;
TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape));
int64 leading_dim = list_shape.dim_size(0);
TensorShape list_element_shape = list_shape;
list_element_shape.RemoveDim(0);
// This checks for the lazy initialization contract set by CreateEmptyList.
// In TensorListReserve if the element_shape is not known at compile time,
// it creates a list with shape [leading_dim, 0].
if (element_shape != list_element_shape) {
if (list_element_shape.num_elements() != 0) {
return errors::InvalidArgument(
"Invalid shape of value in TensorListSetItem. Expected: ",
list_element_shape.DebugString(),
" Actual: ", element_shape.DebugString());
}
list_shape = element_shape;
list_shape.InsertDim(0, leading_dim);
*list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype),
list_shape.dim_sizes());
}
return Status::OK();
}
class TensorListSetItemOp : public XlaOpKernel {
public:
explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -285,7 +345,9 @@ class TensorListSetItemOp : public XlaOpKernel {
xla::XlaOp tl = ctx->Input(0);
TensorShape elem_shape = ctx->InputShape(2);
xla::XlaOp ta = xla::GetTupleElement(tl, 0);
xla::XlaOp list;
OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list));
xla::XlaOp index = ctx->Input(1);
xla::XlaOp value = ctx->Input(2);
@@ -299,8 +361,8 @@ class TensorListSetItemOp : public XlaOpKernel {
auto update = xla::Reshape(value, slice_shape.dim_sizes());
ctx->SetTensorListOutput(
0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
index + xla::ConstantR0<int32>(b, 1)}));
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
xla::GetTupleElement(tl, 1)}));
}
private:
@@ -319,11 +381,14 @@ class TensorListPushBackOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp tl = ctx->Input(0);
xla::XlaOp list_tuple = ctx->Input(0);
TensorShape elem_shape = ctx->InputShape(1);
xla::XlaOp ta = xla::GetTupleElement(tl, 0);
xla::XlaOp index = xla::GetTupleElement(tl, 1);
xla::XlaOp list;
OP_REQUIRES_OK(
ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list));
xla::XlaOp index = xla::GetTupleElement(list_tuple, 1);
xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
@@ -336,7 +401,7 @@ class TensorListPushBackOp : public XlaOpKernel {
auto update = xla::Reshape(value, slice_shape.dim_sizes());
ctx->SetTensorListOutput(
0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices),
index + xla::ConstantR0<int32>(b, 1)}));
}
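
CreateZerosList and GetInitializedList above implement lazy initialization: a list whose element shape is unknown at compile time is represented by a buffer of shape [leading_dim, 0], and the real buffer is allocated on the first SetItem or PushBack once the element shape is seen. A rough NumPy sketch of that contract, purely illustrative:

import numpy as np

def create_zeros_list(leading_dim, element_shape=None, dtype=np.float32):
    # Element shape unknown at "compile time": use [leading_dim, 0] as a marker.
    if element_shape is None:
        return np.zeros((leading_dim, 0), dtype=dtype)
    return np.zeros((leading_dim,) + tuple(element_shape), dtype=dtype)

def get_initialized_list(buf, element_shape, dtype=np.float32):
    # Lazily allocate the real buffer on the first write, once the shape is known.
    if buf.shape[1:] != tuple(element_shape):
        if buf[0].size != 0:
            raise ValueError("Invalid shape of value: expected %s, got %s" %
                             (buf.shape[1:], tuple(element_shape)))
        buf = np.zeros((buf.shape[0],) + tuple(element_shape), dtype=dtype)
    return buf

# First write to a list reserved with unknown element shape: buffer becomes (2, 3).
l = create_zeros_list(2)
l = get_initialized_list(l, (3,))
assert l.shape == (2, 3)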


@@ -0,0 +1,66 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/core/lib/core/bits.h"
namespace tensorflow {
namespace {
class XlaSelfAdjointEigOp : public XlaOpKernel {
public:
explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto result =
xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_);
ctx->SetOutput(0, result.w);
ctx->SetOutput(1, result.v);
}
private:
bool lower_;
int32 max_iter_;
float epsilon_;
};
class SelfAdjointEigV2Op : public XlaOpKernel {
public:
explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape("input");
int n = input_shape.dim_size(input_shape.dims() - 1);
// This is based on heuristics that approx log(n) sweep updates are needed.
// Note: the heuristics provides no theoretical guarantee, max_iter=100 and
// epsilon should be used to determine exit condition.
int max_iter = 2 * tensorflow::Log2Ceiling(n);
auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6);
ctx->SetOutput(0, result.w);
ctx->SetOutput(1, result.v);
}
};
REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes),
XlaSelfAdjointEigOp);
REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes),
SelfAdjointEigV2Op);
} // namespace
} // namespace tensorflow


@@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor
rhs_output: the broadcasted RHS tensor
)doc");
REGISTER_OP("XlaSelfAdjointEig")
.Input("a: T")
.Attr("lower: bool")
.Attr("max_iter: int")
.Attr("epsilon: float")
.Output("w: T")
.Output("v: T")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("T: numbertype")
.Doc(R"doc(
Computes the eigen decomposition of a batch of self-adjoint matrices
(Note: Only real inputs are supported).
Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
i=0...N-1.
a: the input tensor.
lower: a boolean specifies whether the calculation is done with the lower
triangular part or the upper triangular part.
max_iter: maximum number of sweep update, i.e., the whole lower triangular
part or upper triangular part based on parameter lower. Heuristically, it has
been argued that approximatly logN sweeps are needed in practice (Ref: Golub &
van Loan "Matrix Computation").
epsilon: the tolerance ratio.
w: The eigenvalues in ascending order, each repeated according to its
multiplicity.
v: The column v[..., :, i] is the normalized eigenvector corresponding to the
eigenvalue w[..., i].
)doc");
REGISTER_OP("XlaConv") REGISTER_OP("XlaConv")
.Input("lhs: T") .Input("lhs: T")
.Input("rhs: T") .Input("rhs: T")


@@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
name=name)
def self_adjoint_eig(a, lower, max_iter, epsilon):
return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
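
A possible call pattern for the new wrapper, using the 2 * ceil(log2(n)) sweep heuristic that SelfAdjointEigV2Op applies in the kernel above. The argument values and eager-style usage here are illustrative only; in practice the call would sit inside an XLA-compiled scope such as the test harness added in this change:

import math
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla  # module extended by this change

n = 32
x = np.random.uniform(-1.0, 1.0, size=(n, n)).astype(np.float32)
a = tf.constant(x + x.T)  # self-adjoint input

max_iter = 2 * int(math.ceil(math.log(n, 2)))  # sweep-count heuristic from the kernel
w, v = xla.self_adjoint_eig(a, lower=True, max_iter=max_iter, epsilon=1e-6)
# w: eigenvalues in ascending order; v: corresponding eigenvectors as columns.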


@@ -185,9 +185,10 @@ Status BuildComputation(
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
// Keeps track of which retvals have layout to update. The first element is
// the output index, second element is the new layout.
std::vector<std::pair<int64, xla::Layout>> retval_to_update_layout;
// Keeps track of the layout of each retval. If a retval is not in this list,
// a descending layout is used. The first element is the output index, second
// element is the new layout.
std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
const XlaExpression& retval = retvals[i];
@@ -216,7 +217,7 @@ Status BuildComputation(
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
output.shape, output.type));
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
retval_to_update_layout.emplace_back(elems.size(), shape.layout());
retval_index_and_layout.emplace_back(elems.size(), shape.layout());
} else if (it != retval_cores.end()) {
// Apply the sharding to the output, if there is a core assignment.
value = identity_op(value);
@ -289,6 +290,11 @@ Status BuildComputation(
// Ensures the correct sharding is applied to the output. // Ensures the correct sharding is applied to the output.
handle = identity_op(handle); handle = identity_op(handle);
// Set layout of the retval to device representation layout.
if (resource->representation_shape().has_value()) {
retval_index_and_layout.emplace_back(
elems.size(), resource->representation_shape()->layout());
}
elems.push_back(handle); elems.push_back(handle);
} }
} }
@ -318,15 +324,15 @@ Status BuildComputation(
computation->GetProgramShape()); computation->GetProgramShape());
*output_shape = program_shape.result(); *output_shape = program_shape.result();
// Update the output layout to the layout of retval. // Update the output layout to the layout of retval.
for (auto& update : retval_to_update_layout) { for (auto& index_and_layout : retval_index_and_layout) {
if (!always_return_tuple && elems.size() == 1) { if (!always_return_tuple && elems.size() == 1) {
*output_shape->mutable_layout() = update.second; *output_shape->mutable_layout() = index_and_layout.second;
continue; continue;
} }
xla::Shape* output_sub_shape = xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); output_shape, {index_and_layout.first});
*output_sub_shape->mutable_layout() = update.second; *output_sub_shape->mutable_layout() = index_and_layout.second;
} }
return Status::OK(); return Status::OK();
} }


@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
} }
// Tests that the compiler can correctly propagate the layout assigned by
// shape_representation_fn_ to return types.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
// Adds an identity op around the resource to make sure identity ops propagate
// resources correctly.
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
auto write = ops::AssignAddVariableOp(scope, identity, a);
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 3});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 3});
auto options = DefaultOptions();
options.shape_representation_fn =
[](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
return xla_shape;
};
// Compiles the graph.
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
xla::Shape transposed =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
// Check that the return shapes are correctly transposed.
EXPECT_EQ(result.xla_output_shape,
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
// The layout of a resource variable shouldn't change after a transpose.
TEST_F(XlaCompilerTest, TransposeVariables) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
// Adds an identity op around the resource to make sure identity ops propagate
// resources correctly.
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
auto write = ops::AssignAddVariableOp(scope, identity, a);
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
auto transposed_read = ops::Transpose(scope, read, {1, 0});
auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 3});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 3});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
std::move(graph), args, &result));
xla::Shape transposed =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
// Check that the return shapes are correctly transposed.
EXPECT_EQ(result.xla_output_shape,
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
// Tests that the compiler doesn't reorder the parameters. // Tests that the compiler doesn't reorder the parameters.
TEST_F(XlaCompilerTest, MixedOrderArguments) { TEST_F(XlaCompilerTest, MixedOrderArguments) {
for (bool swap_order : {false, true}) { for (bool swap_order : {false, true}) {


@ -319,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
return Status::OK(); return Status::OK();
} }
Status XlaOpKernelContext::ConstantInputAsPartialShape(
int index, PartialTensorShape* shape) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
// If `literal` is a scalar, its value must be -1.
if (literal.shape().rank() == 0) {
int64 shape_val;
TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
if (shape_val != -1) {
return errors::InvalidArgument(
"Cannot convert value to PartialTensorShape: ", shape_val);
}
*shape = PartialTensorShape(); // Shape with unknown rank.
return Status::OK();
}
std::vector<int64> dims;
TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
*shape = PartialTensorShape(dims);
return Status::OK();
}
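A hedged sketch of how a kernel might consume this helper, inside a hypothetical XlaOpKernel subclass; the input index and both branches are illustrative, only ConstantInputAsPartialShape itself comes from this change.

void Compile(XlaOpKernelContext* ctx) override {
  // Input 1 is assumed to hold either a 1D shape vector or the scalar -1.
  PartialTensorShape shape;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsPartialShape(1, &shape));
  if (shape.unknown_rank()) {
    // The scalar -1 case: rank and dimensions are both unknown.
  } else {
    // shape.dims() and shape.dim_size(i) are usable; -1 marks an unknown dim.
  }
}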
Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) { std::vector<TensorShape>* shapes) {
@ -447,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index,
} }
} }
xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
xla::PrimitiveType type;
Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
if (!status.ok()) {
SetStatus(status);
return xla::PRIMITIVE_TYPE_INVALID;
}
return type;
}
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
SetOutputExpression( SetOutputExpression(
index, index,
@ -503,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
handle = xla::Reshape(handle, handle = xla::Reshape(handle,
xla::AsInt64Slice(representation_shape.dimensions())); xla::AsInt64Slice(representation_shape.dimensions()));
} }
variable->SetRepresentationShape(representation_shape);
return variable->SetValue(handle); return variable->SetValue(handle);
} }


@ -138,6 +138,10 @@ class XlaOpKernelContext {
// Converts a constant 1D int32 or int64 tensor into a TensorShape. // Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape); Status ConstantInputAsShape(int index, TensorShape* shape);
// Converts a constant 1D int32 or int64 tensor, or a scalar with value -1
// into a PartialTensorShape.
Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);
// Returns the named list-valued immutable input in "list", as // Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued, // defined in the OpDef. If the named output is not list-valued,
// returns a one-element list. // returns a one-element list.
@ -155,6 +159,11 @@ class XlaOpKernelContext {
return context_->expected_output_dtype(index); return context_->expected_output_dtype(index);
} }
// Returns the type of output `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType output_xla_type(int index);
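A brief, illustrative sketch of the intended use inside a kernel's Compile (not from this change):

  xla::PrimitiveType out_type = ctx->output_xla_type(0);
  if (out_type == xla::PRIMITIVE_TYPE_INVALID) {
    return;  // output_xla_type() has already recorded the error status.
  }
  // ... build ops whose element type must match that of output 0 ...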
// Sets output `index` to the XlaOp `handle`. // Sets output `index` to the XlaOp `handle`.
// All outputs should be set using SetOutput and SetConstantOutput, not // All outputs should be set using SetOutput and SetConstantOutput, not
// via the underlying OpKernelContext. // via the underlying OpKernelContext.


@ -86,6 +86,12 @@ class XlaResource {
// variables have new values that need to be written back. // variables have new values that need to be written back.
const xla::XlaOp& initial_value() const { return initial_value_; } const xla::XlaOp& initial_value() const { return initial_value_; }
// An xla shape that indicates how this resource variable is represented on
// device.
const absl::optional<xla::Shape>& representation_shape() const {
return representation_shape_;
}
// A variable is initialized if it has a value. // A variable is initialized if it has a value.
bool initialized() const { return value_.valid(); } bool initialized() const { return value_.valid(); }
@ -100,6 +106,11 @@ class XlaResource {
// Sets the current value of the resource to an all-zero value. // Sets the current value of the resource to an all-zero value.
Status SetZeroValue(xla::XlaBuilder* builder); Status SetZeroValue(xla::XlaBuilder* builder);
// Sets the representational shape of the resource on device.
void SetRepresentationShape(const xla::Shape& shape) {
representation_shape_ = absl::make_optional(shape);
}
// Looks up the gradient for `source`, or creates it if it does not already // Looks up the gradient for `source`, or creates it if it does not already
// exist. The call target must be an initialized TensorArray resource. A // exist. The call target must be an initialized TensorArray resource. A
// TensorArray can have multiple named gradients; see the operator // TensorArray can have multiple named gradients; see the operator
@ -160,6 +171,10 @@ class XlaResource {
xla::XlaOp value_; xla::XlaOp value_;
xla::XlaOp initial_value_; xla::XlaOp initial_value_;
// An xla shape that indicates how this resource variable is represented on
// device.
absl::optional<xla::Shape> representation_shape_;
int64 max_array_size_ = -1; int64 max_array_size_ = -1;
bool tensor_array_multiple_writes_aggregate_ = false; bool tensor_array_multiple_writes_aggregate_ = false;


@ -452,11 +452,12 @@ cc_library(
) )
cc_library( cc_library(
name = "self_adjoint_eigen", name = "self_adjoint_eig",
srcs = ["self_adjoint_eigen.cc"], srcs = ["self_adjoint_eig.cc"],
hdrs = ["self_adjoint_eigen.h"], hdrs = ["self_adjoint_eig.h"],
deps = [ deps = [
":arithmetic", ":arithmetic",
":comparators",
":constants", ":constants",
":loops", ":loops",
":math", ":math",
@ -473,9 +474,12 @@ cc_library(
) )
xla_test( xla_test(
name = "self_adjoint_eigen_test", name = "self_adjoint_eig_test",
size = "medium", srcs = ["self_adjoint_eig_test.cc"],
srcs = ["self_adjoint_eigen_test.cc"], blacklisted_backends = [
"cpu",
"gpu",
],
real_hardware_only = True, real_hardware_only = True,
shard_count = 10, shard_count = 10,
tags = ["optonly"], tags = ["optonly"],
@ -483,7 +487,7 @@ xla_test(
":arithmetic", ":arithmetic",
":constants", ":constants",
":matrix", ":matrix",
":self_adjoint_eigen", ":self_adjoint_eig",
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",


@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/math.h"
@ -42,7 +43,6 @@ namespace {
struct SymmetricSchurDecomposition { struct SymmetricSchurDecomposition {
XlaOp c; // cosine. XlaOp c; // cosine.
XlaOp s; // sine. XlaOp s; // sine.
XlaOp reduction; // Reduction in the off diagonal after applying G.
}; };
// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix // JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix
@ -51,7 +51,11 @@ struct SymmetricSchurDecomposition {
struct JacobiUpdate { struct JacobiUpdate {
XlaOp v; XlaOp v;
XlaOp w; XlaOp w;
};
struct FrobeniusNorms {
XlaOp off_diagonal_norm; XlaOp off_diagonal_norm;
XlaOp total_norm;
}; };
// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, // Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
@ -79,10 +83,6 @@ StatusOr<SymmetricSchurDecomposition> SymmetricShurDecomposition2x2(XlaOp a,
XlaBuilder* builder = a.builder(); XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
PrimitiveType type = a_shape.element_type();
const int64 num_dims = a_shape.rank();
auto zero = ScalarLike(a, 0.0); auto zero = ScalarLike(a, 0.0);
auto one = ScalarLike(a, 1.0); auto one = ScalarLike(a, 1.0);
auto two = ScalarLike(a, 2.0); auto two = ScalarLike(a, 2.0);
@ -110,9 +110,7 @@ StatusOr<SymmetricSchurDecomposition> SymmetricShurDecomposition2x2(XlaOp a,
schur.c = c * rnorm; schur.c = c * rnorm;
schur.s = s * rnorm; schur.s = s * rnorm;
schur.reduction =
Reduce(two * Square(pqs), zero, CreateScalarAddComputation(type, builder),
{num_dims - 2, num_dims - 1});
return schur; return schur;
} }
@ -196,12 +194,32 @@ StatusOr<JacobiUpdate> Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q,
jacobi_update.v = jacobi_update.v =
DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q});
jacobi_update.off_diagonal_norm = Sqrt(
Max(Square(jacobi_update.off_diagonal_norm) - schur.reduction, pq_zero));
return jacobi_update; return jacobi_update;
} }
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
XlaBuilder* builder = w.builder();
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
const int64 num_dims = shape.rank();
auto frobenius_norm =
Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
CreateScalarAddComputation(shape.element_type(), builder),
{num_dims - 2, num_dims - 1}));
auto diag = GetMatrixDiagonal(w);
auto diag_square =
Reduce(Square(diag), ScalarLike(w, 0.0),
CreateScalarAddComputation(shape.element_type(), builder),
{num_dims - 2});
FrobeniusNorms frobenius_norms;
frobenius_norms.off_diagonal_norm =
Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
frobenius_norms.total_norm = frobenius_norm;
return frobenius_norms;
}
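Written out (a restatement of the code above, not new behavior), the helper computes

  \|W\|_F = \sqrt{\sum_{i,j} W_{ij}^2}, \qquad
  \mathrm{off}(W) = \sqrt{\max\big(\|W\|_F^2 - \sum_i W_{ii}^2,\ 0\big)},

and the while-loop condition below stops sweeping once \mathrm{off}(W) \le \epsilon \cdot \|W\|_F.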
StatusOr<std::vector<XlaOp>> WhileLoopFn( StatusOr<std::vector<XlaOp>> WhileLoopFn(
absl::Span<const XlaOp> initial_values, // absl::Span<const XlaOp> initial_values, //
int matrix_dimension, // int matrix_dimension, //
@ -212,62 +230,108 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
auto while_cond_fn = [&](absl::Span<const XlaOp> values, auto while_cond_fn = [&](absl::Span<const XlaOp> values,
XlaBuilder* cond_builder) -> StatusOr<XlaOp> { XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
auto k = values[0]; auto k = values[0];
auto off_diagonal_norm = values[5];
// tol = frobenius_norm * epsilon.
auto tol = values[6] * values[7];
auto max_sweeps = ScalarLike(k, max_sweep_updates); auto max_sweeps = ScalarLike(k, max_sweep_updates);
auto sweep_update_cond = Gt(max_sweeps, k); auto sweep_update_cond = Gt(max_sweeps, k);
auto tol_cond = ReduceAll(Lt(tol, off_diagonal_norm), auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie();
auto tol = norms.total_norm * values[3];
auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
xla::ConstantR0<bool>(cond_builder, false), xla::ConstantR0<bool>(cond_builder, false),
CreateScalarOrComputation(PRED, cond_builder)); CreateScalarOrComputation(PRED, cond_builder));
return And(tol_cond, sweep_update_cond);
return And(sweep_update_cond, tol_cond);
}; };
auto while_body_fn = auto while_body_fn =
[&](absl::Span<const XlaOp> values, [&](absl::Span<const XlaOp> values,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> { XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
auto zero = Zero(body_builder, index_type); auto while_cond_fn_inner =
auto one = One(body_builder, index_type); [&](absl::Span<const XlaOp> values_inner,
auto end_index = ScalarLike(one, matrix_dimension); XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
auto p = values_inner[0];
return Lt(p, ScalarLike(p, matrix_dimension - 1));
};
// Indexes. auto while_body_fn_inner =
XlaOp k = values[0]; [&](absl::Span<const XlaOp> values_inner,
XlaOp p = values[1]; XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp q = values[2]; auto while_cond_fn_innermost =
[&](absl::Span<const XlaOp> values_innermost,
XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
auto q = values_innermost[1];
return Lt(q, ScalarLike(q, matrix_dimension));
};
auto while_body_fn_innermost =
[&](absl::Span<const XlaOp> values_innermost,
XlaBuilder* innermost_body_builder)
-> StatusOr<std::vector<XlaOp>> {
auto p = values_innermost[0];
auto q = values_innermost[1];
JacobiUpdate jacobi_update; JacobiUpdate jacobi_update;
jacobi_update.v = values[3]; jacobi_update.v = values_innermost[2];
jacobi_update.w = values[4]; jacobi_update.w = values_innermost[3];
jacobi_update.off_diagonal_norm = values[5];
XlaOp frobenius_norm = values[6]; auto tol = values_innermost[4];
XlaOp tol = values[7];
TF_ASSIGN_OR_RETURN(jacobi_update, TF_ASSIGN_OR_RETURN(jacobi_update,
Update(jacobi_update, p, q, tol, matrix_dimension)); Update(jacobi_update, p, q, tol, matrix_dimension));
std::vector<XlaOp> updated_values_innermost;
updated_values_innermost.reserve(values_innermost.size());
updated_values_innermost.push_back(p);
updated_values_innermost.push_back(q + ScalarLike(q, 1));
updated_values_innermost.push_back(jacobi_update.v);
updated_values_innermost.push_back(jacobi_update.w);
updated_values_innermost.push_back(tol);
return updated_values_innermost;
};
std::vector<XlaOp> values_innermost(5);
auto p = values_inner[0];
auto q = p + ScalarLike(p, 1);
values_innermost[0] = p; // index p.
values_innermost[1] = q; // index q.
values_innermost[2] = values_inner[1]; // v.
values_innermost[3] = values_inner[2]; // w.
values_innermost[4] = values_inner[3]; // tol.
TF_ASSIGN_OR_RETURN(
values_innermost,
WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
values_innermost, absl::StrCat(name, "-Innermost"),
inner_body_builder));
std::vector<XlaOp> updated_values_inner;
updated_values_inner.reserve(values_inner.size());
updated_values_inner.push_back(p + ScalarLike(p, 1));
updated_values_inner.push_back(values_innermost[2]);
updated_values_inner.push_back(values_innermost[3]);
updated_values_inner.push_back(values_innermost[4]);
return updated_values_inner;
};
// Indexes.
XlaOp k = values[0];
std::vector<XlaOp> values_inner(4);
values_inner[0] = ScalarLike(k, 0); // index p.
values_inner[1] = values[1]; // v.
values_inner[2] = values[2]; // w.
values_inner[3] = values[3]; // tol.
TF_ASSIGN_OR_RETURN(
values_inner,
WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
absl::StrCat(name, "-Inner"), body_builder));
std::vector<XlaOp> updated_values; std::vector<XlaOp> updated_values;
updated_values.reserve(values.size()); updated_values.reserve(values_inner.size());
q = q + one; updated_values.push_back(k + ScalarLike(k, 1));
p = Select(Eq(q, end_index), p + one, p); updated_values.push_back(values_inner[1]);
k = Select(Eq(p, end_index - one), k + one, k); updated_values.push_back(values_inner[2]);
p = Select(Eq(p, end_index - one), zero, p); updated_values.push_back(values_inner[3]);
q = Select(Eq(q, end_index), p + one, q);
updated_values.push_back(k);
updated_values.push_back(p);
updated_values.push_back(q);
updated_values.push_back(jacobi_update.v);
updated_values.push_back(jacobi_update.w);
updated_values.push_back(jacobi_update.off_diagonal_norm);
updated_values.push_back(frobenius_norm);
updated_values.push_back(tol);
return updated_values; return updated_values;
}; };
@ -278,6 +342,27 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
return values; return values;
} }
StatusOr<SelfAdjointEigResult> SortByEigenvalues(SelfAdjointEigResult result) {
XlaBuilder* builder = result.v.builder();
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v));
const int64 num_dims = shape.rank();
auto dimensions = shape.dimensions();
std::vector<int64> broadcast_dims(num_dims - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
broadcast_dims[num_dims - 2] = num_dims - 1;
result.w = BroadcastInDim(result.w, dimensions, broadcast_dims);
XlaOp sort_result =
Sort({result.w, result.v},
CreateScalarLtComputation(
{shape.element_type(), shape.element_type()}, builder),
num_dims - 1);
result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
result.v = GetTupleElement(sort_result, 1);
return result;
}
} // namespace } // namespace
// This is the cyclic Jacobi iteration. Please note that the eigenvalues are // This is the cyclic Jacobi iteration. Please note that the eigenvalues are
@ -286,31 +371,35 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
// def jacobi(A): // def jacobi(A):
// n, _ = A.shape // n, _ = A.shape
// V = np.eye(n) // V = np.eye(n)
// nfrob = np.sum(A ** 2) // frobenius_norm = np.linalg.norm(A)
// ndiag = np.sum(np.diag(A) ** 2) // diag_norm = np.linalg.norm(np.diag(A))
// off = nfrob - ndiag // off_diag_norm = np.sqrt(
// while off > 1e-6 * nfrob: // frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
// while off_diag_norm > 1e-6 * frobenius_norm:
// for p in range(n - 1): // for p in range(n - 1):
// for q in range(p + 1, n): // for q in range(p + 1, n):
// if off > 1e-6 * nfrob:
// c, s = sym_schur2x2(A, p, q) // c, s = sym_schur2x2(A, p, q)
// off = off - 2 * A[p, q] ** 2
// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), // A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]),
// A[[p, q], :]) // A[[p, q], :])
// A[:, [p, q]] = np.matmul(A[:, [p, q]], // A[:, [p, q]] = np.matmul(A[:, [p, q]],
// np.array([[c, s], [-s, c]])) // np.array([[c, s], [-s, c]]))
// V[:, [p, q]] = np.matmul(V[:, [p, q]], // V[:, [p, q]] = np.matmul(V[:, [p, q]],
// np.array([[c, s], [-s, c]])) // np.array([[c, s], [-s, c]]))
// frobenius_norm = np.linalg.norm(A)
// diag_norm = np.linalg.norm(np.diag(A))
// off_diag_norm = np.sqrt(
// frobenius_norm - diag_norm) * np.sqrt(
// frobenius_norm + diag_norm)
// //
// return A, V // return A, V
// //
// TODO(kuny): Implement parallel order Jacobi. // TODO(kuny): Implement parallel order Jacobi.
// //
SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
float epsilon) { float epsilon) {
XlaBuilder* builder = a.builder(); XlaBuilder* builder = a.builder();
auto return_error = [&](const Status& status) { auto return_error = [&](const Status& status) {
SelfAdjointEigenResult result; SelfAdjointEigResult result;
result.v = builder->ReportError(status); result.v = builder->ReportError(status);
result.w = builder->ReportError(status); result.w = builder->ReportError(status);
return result; return result;
@ -348,33 +437,17 @@ SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter,
batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
} }
auto zero = ScalarLike(a, 0.0);
auto tol = ScalarLike(a, epsilon); auto tol = ScalarLike(a, epsilon);
auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
auto w_init = Triangle(a, lower); auto w_init = Triangle(a, lower);
w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init;
auto frobenius_norm = Sqrt(Reduce(Square(w_init), zero,
CreateScalarAddComputation(type, builder),
{num_dims - 2, num_dims - 1}));
auto diag = GetMatrixDiagonal(w_init);
auto diag_square =
Reduce(Square(diag), zero, CreateScalarAddComputation(type, builder),
{num_dims - 2});
auto off_diagonal_init =
Sqrt(Max(Square(frobenius_norm) - diag_square, zero));
auto output_with_status = WhileLoopFn( auto output_with_status = WhileLoopFn(
{ {
Zero(builder, S32), // k Zero(builder, S32), // k
Zero(builder, S32), // p v_init, // v
One(builder, S32), // q w_init, // w
v_init, //
w_init, //
off_diagonal_init, //
frobenius_norm, //
tol, // tol, //
}, // }, //
n, // n, //
@ -388,11 +461,11 @@ SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter,
auto output = output_with_status.ValueOrDie(); auto output = output_with_status.ValueOrDie();
SelfAdjointEigenResult result; SelfAdjointEigResult result;
result.v = output[3]; result.v = output[1];
result.w = GetMatrixDiagonal(output[4]); result.w = GetMatrixDiagonal(output[2]);
return result; return SortByEigenvalues(result).ValueOrDie();
} }
} // namespace xla } // namespace xla


@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -23,20 +23,18 @@ namespace xla {
// The eigenvalue decomposition of a symmetric matrix, the original matrix is // The eigenvalue decomposition of a symmetric matrix, the original matrix is
// recovered by v * w * v_t. // recovered by v * w * v_t.
struct SelfAdjointEigenResult { struct SelfAdjointEigResult {
// The i-th column is the normalized eigenvector corresponding to the // The i-th column is the normalized eigenvector corresponding to the
// eigenvalue w[i]. Will return a matrix object if a is a matrix object. // eigenvalue w[i]. Will return a matrix object if a is a matrix object.
XlaOp v; XlaOp v;
// TODO(kuny): Sort the eigenvalues.
// The eigenvalues in ascending order, each repeated according to its // The eigenvalues in ascending order, each repeated according to its
// multiplicity. // multiplicity.
XlaOp w; XlaOp w;
}; };
SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower = true, SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true,
int64 max_iter = 100, int64 max_iter = 100, float epsilon = 1e-6);
float epsilon = 1e-6);
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array3d.h"
@ -32,7 +32,7 @@ limitations under the License.
namespace xla { namespace xla {
class SelfAdjointEigenTest : public ClientLibraryTestBase { class SelfAdjointEigTest : public ClientLibraryTestBase {
protected: protected:
void SetUp() override { void SetUp() override {
ClientLibraryTestBase::SetUp(); ClientLibraryTestBase::SetUp();
@ -71,7 +71,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase {
} }
void TearDown() override { ClientLibraryTestBase::TearDown(); } void TearDown() override { ClientLibraryTestBase::TearDown(); }
Array3D<float> get_unit_matrix_3d(const Array3D<float>& matrix) { Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
for (int i = 0; i < matrix.n1(); ++i) { for (int i = 0; i < matrix.n1(); ++i) {
for (int j = 0; j < matrix.n2(); ++j) { for (int j = 0; j < matrix.n2(); ++j) {
@ -100,7 +100,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase {
return result; return result;
} }
XlaOp ComputeMatmulVWVt(SelfAdjointEigenResult result, XlaBuilder* builder) { XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
Shape shape = builder->GetShape(result.v).ValueOrDie(); Shape shape = builder->GetShape(result.v).ValueOrDie();
std::vector<int64> out_dims = shape.dimensions(); std::vector<int64> out_dims = shape.dimensions();
std::vector<int64> broadcast_dims(shape.rank() - 1); std::vector<int64> broadcast_dims(shape.rank() - 1);
@ -140,69 +140,69 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase {
Array2D<int> wrong_type_4x4_; Array2D<int> wrong_type_4x4_;
}; };
XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_2x4x4) { XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a); auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
ComputeMatmulVWVt(result, &builder); ComputeMatmulVWVt(result, &builder);
ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Lower_2x4x4) { XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR3Parameter<float>( auto a_data = CreateR3Parameter<float>(
ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
ComputeMatmulVWVt(result, &builder); ComputeMatmulVWVt(result, &builder);
ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Upper_2x4x4) { XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR3Parameter<float>( auto a_data = CreateR3Parameter<float>(
ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a, false); auto result = SelfAdjointEig(a, false);
ComputeMatmulVWVt(result, &builder); ComputeMatmulVWVt(result, &builder);
ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()}, ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_2x4x4) { XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a); auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
ComputeAndCompareR3<float>(&builder, get_unit_matrix_3d(batch_3d_4x4_), ComputeAndCompareR3<float>(&builder, GetUnitMatrix3D(batch_3d_4x4_),
{a_data.get()}, ErrorSpec(1e-3, 1e-3)); {a_data.get()}, ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
ComputeMatmulVWVt(result, &builder); ComputeMatmulVWVt(result, &builder);
ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()}, ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()},
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) { XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
// This is computed by numpy.linalg.eigh with float32. // This is computed by numpy.linalg.eigh with float32.
@ -211,21 +211,21 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) {
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
Sort(result.w); Add(result.w, ZerosLike(result.w));
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()}, ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) { XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
float expected_vals = 1e-3; float expected_vals = 1e-3;
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
// np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2
GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8),
BatchDot(TransposeInMinorDims(result.v), result.v), BatchDot(TransposeInMinorDims(result.v), result.v),
@ -235,66 +235,79 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) {
ErrorSpec(1e-3, 1e-3)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Wrong_Type_Int) { XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
EXPECT_FALSE(result.v.valid()); EXPECT_FALSE(result.v.valid());
EXPECT_FALSE(result.w.valid()); EXPECT_FALSE(result.w.valid());
} }
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_8x8) { XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
int size = 8; int size = 8;
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-2, 1e-2)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_16x16) { XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
int size = 16; int size = 16;
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-2, 1e-2)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_32x32) { XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
int size = 32; int size = 32;
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-2, 1e-2)); ErrorSpec(1e-3, 1e-3));
} }
XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_64x64) { XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
int size = 64; int size = 256;
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size); Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
XlaOp a; XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SelfAdjointEigen(a); auto result = SelfAdjointEig(a);
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-2, 1e-2)); ErrorSpec(1e-3, 1e-3));
}
XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) {
XlaBuilder builder(TestName());
int size = 512;
Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SelfAdjointEig(a);
GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-3, 1e-3));
} }
} // namespace xla } // namespace xla


@ -36,7 +36,8 @@ XlaOp TopK(XlaOp input, int64 k) {
XlaOp sort_result = XlaOp sort_result =
Sort({Neg(input), iota_s32}, Sort({Neg(input), iota_s32},
CreateScalarLtComputation({input_shape.element_type(), S32}, CreateScalarLtComputation({input_shape.element_type(), S32},
iota_s32.builder())); iota_s32.builder()),
last_dim, /*is_stable=*/true);
std::vector<int64> start_indices(input_shape.dimensions_size(), 0); std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end()); std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
limit_indices[last_dim] = k; limit_indices[last_dim] = k;


@ -81,9 +81,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
ComputeAndCompareR1<float>(&builder, inputs, {}); ComputeAndCompareR1<float>(&builder, inputs, {});
} }
// TODO(b/122298745): Enable this test when the GPU backend supports stable XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
// sorting.
XLA_TEST_F(SortingTest, DISABLED_ON_GPU(TopKFullSortWithDuplicates)) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp a; XlaOp a;
auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a); auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a);


@ -1663,14 +1663,16 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
Lt(first_lhs_param, first_rhs_param); Lt(first_lhs_param, first_rhs_param);
TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); TF_ASSIGN_OR_RETURN(auto comparator, b->Build());
return Sort(operands, comparator, dimension); return Sort(operands, comparator, dimension, /*is_stable=*/false);
}); });
} }
XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands, XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
const XlaComputation& comparator, int64 dimension) { const XlaComputation& comparator, int64 dimension,
bool is_stable) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr; HloInstructionProto instr;
instr.set_is_stable(is_stable);
std::vector<const Shape*> operand_shape_ptrs; std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes, TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
GetOperandShapes(operands)); GetOperandShapes(operands));
@ -3320,8 +3322,9 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values, int64 dimension) {
} }
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator, XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
int64 dimension) { int64 dimension, bool is_stable) {
return operands[0].builder()->Sort(operands, comparator, dimension); return operands[0].builder()->Sort(operands, comparator, dimension,
is_stable);
} }
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {


@ -505,7 +505,7 @@ class XlaBuilder {
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {}, XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
int64 dimension = -1); int64 dimension = -1);
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator, XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
int64 dimension = -1); int64 dimension = -1, bool is_stable = false);
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@ -923,7 +923,8 @@ class XlaBuilder {
friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values, friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
int64 dimension); int64 dimension);
friend XlaOp Sort(absl::Span<const XlaOp> operands, friend XlaOp Sort(absl::Span<const XlaOp> operands,
const XlaComputation& comparator, int64 dimension); const XlaComputation& comparator, int64 dimension,
bool is_stable);
friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands, friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const XlaComputation& computation,
@ -1695,7 +1696,8 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
int64 dimension = -1); int64 dimension = -1);
// Enqueues a sort instruction onto the computation, using 'comparator' for // Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. // comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided: // If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array. // * If the operand is a rank-1 tensor (an array), the result is a sorted array.
// The resulting sorting order has the property that for all index positions // The resulting sorting order has the property that for all index positions
@ -1718,7 +1720,7 @@ XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
// correspond to the value of operand i at two index positions. // correspond to the value of operand i at two index positions.
// Default comparator computations can be found in lib/comparators.h // Default comparator computations can be found in lib/comparators.h
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator, XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
int64 dimension = -1); int64 dimension = -1, bool is_stable = false);
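A hedged sketch of the new argument in use, mirroring the TopK call site earlier in this change; the wrapper function and the F32 element type are illustrative.

#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sorts `values` along its last dimension, preserving the relative order of
// equal elements (is_stable = true).
xla::XlaOp StableAscendingSort(xla::XlaOp values, xla::XlaBuilder* builder) {
  xla::XlaComputation comparator =
      xla::CreateScalarLtComputation({xla::F32}, builder);
  return xla::Sort({values}, comparator, /*dimension=*/-1, /*is_stable=*/true);
}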
// Enqueues a clamp instruction onto the computation. // Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);


@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
} }
ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
DeviceAssignment* device_assignment) { const DeviceAssignment* device_assignment) {
device_assignment_ = device_assignment; device_assignment_ = device_assignment;
return *this; return *this;
} }


@ -74,7 +74,7 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
ExecutableRunOptions& set_device_assignment( ExecutableRunOptions& set_device_assignment(
DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment);
const DeviceAssignment* device_assignment() const; const DeviceAssignment* device_assignment() const;
ExecutableRunOptions& set_rng_seed(int rng_seed); ExecutableRunOptions& set_rng_seed(int rng_seed);
@ -83,7 +83,7 @@ class ExecutableRunOptions {
private: private:
DeviceMemoryAllocator* allocator_ = nullptr; DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1; int device_ordinal_ = -1;
DeviceAssignment* device_assignment_ = nullptr; const DeviceAssignment* device_assignment_ = nullptr;
stream_executor::Stream* stream_ = nullptr; stream_executor::Stream* stream_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr; ExecutionProfile* execution_profile_ = nullptr;


@ -77,6 +77,7 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",


@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
@ -53,74 +54,6 @@ namespace swig {
// TODO(b/118641336): Factor out XRT parts into a small c++ library of their // TODO(b/118641336): Factor out XRT parts into a small c++ library of their
// own. // own.
// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
// device handles instead of needing to set the number of replicas at XLA
// service initialization time.
tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED);
int g_replica_count GUARDED_BY(g_local_client_mutex) = 1;
LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr;
string* GetPlatformNameString() {
static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) =
new string("Host");
return platform_name_string;
}
Status InitializeReplicaCount(int replica_count) {
if (replica_count < 1) {
return InvalidArgument("Replica count must be >= 1; got %d.",
replica_count);
}
tensorflow::mutex_lock lock(g_local_client_mutex);
if (g_local_client != nullptr) {
return FailedPrecondition(
"Attempted to set the replica count to %d, but a local XLA service was "
"previously created with a replica count of %d.",
replica_count, g_replica_count);
}
g_replica_count = replica_count;
return Status::OK();
}
Status InitializePlatformName(const string& platform_name) {
string* g_platform_name = GetPlatformNameString();
tensorflow::mutex_lock lock(g_local_client_mutex);
if (g_local_client != nullptr) {
return FailedPrecondition(
"Attempted to set the platform name to %s, but a local XLA service was "
"previously created with a platform name of %s.",
platform_name, *g_platform_name);
}
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform(platform_name));
if (platform->VisibleDeviceCount() <= 0) {
return InvalidArgument("Platform %s has no visible devices.",
platform_name);
}
*g_platform_name = platform_name;
return Status::OK();
}
int GetReplicaCount() {
tensorflow::mutex_lock lock(g_local_client_mutex);
return g_replica_count;
}
StatusOr<LocalClient*> GetOrCreateLocalClient() {
string* platform_name = GetPlatformNameString();
tensorflow::mutex_lock lock(g_local_client_mutex);
if (g_local_client != nullptr) {
return g_local_client;
}
LocalClientOptions options;
options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie());
options.set_number_of_replicas(g_replica_count);
TF_ASSIGN_OR_RETURN(g_local_client,
ClientLibrary::GetOrCreateLocalClient(options));
CHECK(g_local_client != nullptr);
return g_local_client;
}
Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; const char* name = "xla._CPU_CUSTOM_CALL_TARGET";
if (!PyCapsule_IsValid(capsule, name)) { if (!PyCapsule_IsValid(capsule, name)) {
@ -135,62 +68,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) {
return Status::OK(); return Status::OK();
} }
Status TransferToInfeedLocal(const Literal& literal) { LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {}
VLOG(1) << "Infeeding literal without replica number; shape: "
<< literal.shape(); /* static */ StatusOr<LocalClient> LocalClient::Get(
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); const string& platform_name) {
return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform(platform_name));
if (platform->VisibleDeviceCount() <= 0) {
return InvalidArgument("Platform %s has no visible devices.",
platform_name);
}
LocalClientOptions options;
options.set_platform(platform);
TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
CHECK(client != nullptr);
return LocalClient(client);
} }
Status TransferToInfeedLocalReplica(const Literal& literal, // Returns the number of devices known to the XLA client.
int replica_number) { int LocalClient::DeviceCount() const { return client_->device_count(); }
VLOG(1) << "Infeeding shape " << literal.shape()
<< " to replica number: " << replica_number; Status LocalClient::TransferToInfeed(const Literal& literal,
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); int device_ordinal) {
TF_ASSIGN_OR_RETURN(int device_ordinal, VLOG(1) << "Infeeding literal to device " << device_ordinal
client->ReplicaNumberToDeviceOrdinal(replica_number)); << "; shape: " << literal.shape();
return client->TransferToInfeedLocal(literal, device_ordinal); return client_->TransferToInfeed(literal, device_ordinal);
} }
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape, StatusOr<Literal> LocalClient::TransferFromOutfeed(const Shape& shape,
int replica_number) { int device_ordinal) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number VLOG(1) << "Outfeeding literal from device " << device_ordinal
<< " shape: " << shape; << "; shape: " << shape;
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client_->TransferFromOutfeed(&shape, device_ordinal);
TF_ASSIGN_OR_RETURN(int device_ordinal,
client->ReplicaNumberToDeviceOrdinal(replica_number));
return client->TransferFromOutfeedLocal(shape, device_ordinal);
}
static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
int device_ordinal,
const Literal& arg) {
return client->LiteralToShapedBuffer(arg, device_ordinal,
client->backend().memory_allocator());
} }
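A hedged sketch (same xla::swig namespace) of the client-scoped flow that replaces the old process-global replica count and platform name; the function name and the "Host" platform string are illustrative.

Status InfeedOnFirstDevice(const Literal& literal) {
  TF_ASSIGN_OR_RETURN(LocalClient client, LocalClient::Get("Host"));
  if (client.DeviceCount() < 1) {
    return InvalidArgument("Platform has no visible devices.");
  }
  // Devices are addressed by ordinal rather than by replica number.
  return client.TransferToInfeed(literal, /*device_ordinal=*/0);
}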
/* static */ /* static */
StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral( StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
const Literal& argument, const absl::optional<Shape>& shape_with_layout, const Literal& argument, const absl::optional<Shape>& shape_with_layout,
int replica_number) { const LocalClient& client, int device_ordinal) {
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
TF_ASSIGN_OR_RETURN(int device_ordinal, << device_ordinal;
client->ReplicaNumberToDeviceOrdinal(replica_number)); auto literal_to_buffer = [&](const Literal& arg) {
VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " return client.client()->LiteralToShapedBuffer(
<< replica_number << "/" << device_ordinal; arg, device_ordinal, client.client()->backend().memory_allocator());
};
StatusOr<ScopedShapedBuffer> buf = [&] { StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) { if (shape_with_layout) {
Literal relaid = argument.Relayout(shape_with_layout.value()); Literal relaid = argument.Relayout(shape_with_layout.value());
return ToBuffer(client, device_ordinal, relaid); return literal_to_buffer(relaid);
} }
return ToBuffer(client, device_ordinal, argument); return literal_to_buffer(argument);
}(); }();
TF_RETURN_IF_ERROR(buf.status()); TF_RETURN_IF_ERROR(buf.status());
return new LocalShapedBuffer(std::move(buf).ValueOrDie()); return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client());
} }
LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
: shaped_buffer_(std::move(shaped_buffer)) {} xla::LocalClient* client)
: shaped_buffer_(std::move(shaped_buffer)), client_(client) {}
const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
return &shaped_buffer_; return &shaped_buffer_;
@ -203,8 +140,7 @@ const Shape& LocalShapedBuffer::shape() const {
} }
StatusOr<Literal> LocalShapedBuffer::ToLiteral() const { StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client_->ShapedBufferToLiteral(*shaped_buffer());
return client->ShapedBufferToLiteral(*shaped_buffer());
} }
LocalShapedBufferTuple::LocalShapedBufferTuple( LocalShapedBufferTuple::LocalShapedBufferTuple(
@ -235,6 +171,51 @@ StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
int64 LocalShapedBufferTuple::size() const { return elements_.size(); } int64 LocalShapedBufferTuple::size() const { return elements_.size(); }
StatusOr<LocalShapedBufferTuple*> LocalShapedBuffer::DestructureTuple() {
const Shape tuple_shape = shape();
if (!tuple_shape.IsTuple()) {
return InvalidArgument(
"Attemped to destructure a LocalShapedBuffer that did not have a tuple "
"shape; shape: %s",
ShapeUtil::HumanString(tuple_shape));
}
DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator();
ShapedBuffer tuple_buffer = Release();
// Extract some metadata we use to construct scoped buffers.
const se::Platform* platform = tuple_buffer.platform();
int device_ordinal = tuple_buffer.device_ordinal();
ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
std::vector<LocalShapedBuffer*> results;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
// Create a shaped buffer for this destructured tuple element.
const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);
ShapeUtil::ForEachSubshape(
subshape, [&](const Shape& s, const ShapeIndex& index) {
ShapeIndex original(index);
original.push_front(i);
se::DeviceMemoryBase* device_memory =
shape_tree.mutable_element(original);
shaped_buffer.set_buffer(*device_memory, index);
*device_memory = se::DeviceMemoryBase();
});
VLOG(3) << "Completed tuple element: " << i;
results.push_back(new LocalShapedBuffer(
ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
}
// Deallocate the root buffer.
se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
return new LocalShapedBufferTuple(std::move(results));
}
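A brief, hedged usage sketch of the DestructureTuple method defined just above: it now lives on LocalShapedBuffer and reuses the buffer's own client_, so callers no longer need a global client handle. The helper name, error handling, and omitted includes are assumptions, not part of the change.

// Illustrative sketch only: splits a tuple-shaped device buffer and copies
// element `index` back to the host. Includes are omitted because file paths
// are not shown in this diff.
xla::StatusOr<xla::Literal> ReadTupleElement(
    xla::swig::LocalShapedBuffer* tuple_buffer, int index) {
  // DestructureTuple hands ownership of the per-element buffers to the
  // returned LocalShapedBufferTuple and deallocates the original root buffer.
  TF_ASSIGN_OR_RETURN(xla::swig::LocalShapedBufferTuple* elements,
                      tuple_buffer->DestructureTuple());
  if (index < 0 || index >= elements->size()) {
    return xla::InvalidArgument("Tuple element index %d is out of range.",
                                index);
  }
  // Release(index) transfers ownership of that element to the caller.
  TF_ASSIGN_OR_RETURN(xla::swig::LocalShapedBuffer* element,
                      elements->Release(index));
  // Cleanup of `elements` and `element` (normally done from Python via the
  // Delete* helpers) is elided here.
  return element->ToLiteral();
}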
XrtAllocation::XrtAllocation(int64 handle, Shape shape, XrtAllocation::XrtAllocation(int64 handle, Shape shape,
const string& session_target) const string& session_target)
: handle_(handle), shape_(shape), session_target_(session_target) {} : handle_(handle), shape_(shape), session_target_(session_target) {}
@ -332,23 +313,32 @@ StatusOr<XrtAllocation*> XrtAllocationTuple::Release(int i) {
int64 XrtAllocationTuple::size() const { return elements_.size(); } int64 XrtAllocationTuple::size() const { return elements_.size(); }
CompiledLocalComputation::CompiledLocalComputation( LocalExecutable::LocalExecutable(
std::unique_ptr<LocalExecutable> executable) std::unique_ptr<xla::LocalExecutable> executable,
: executable_(std::move(executable)) {} xla::DeviceAssignment device_assignment, xla::LocalClient* client)
: executable_(std::move(executable)),
device_assignment_(std::move(device_assignment)),
client_(client) {}
StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute( std::vector<int> LocalExecutable::DeviceOrdinals() const {
int num_replicas = device_assignment_.replica_count();
std::vector<int> device_ordinals;
device_ordinals.reserve(num_replicas);
for (int i = 0; i < num_replicas; ++i) {
device_ordinals.push_back(device_assignment_(i, 0));
}
return device_ordinals;
}
StatusOr<LocalShapedBuffer*> LocalExecutable::Execute(
absl::Span<LocalShapedBuffer* const> argument_handles) { absl::Span<LocalShapedBuffer* const> argument_handles) {
if (num_replicas() != 1) { if (num_replicas() != 1) {
return InvalidArgument( return InvalidArgument(
"Attempted to execute computation with %d replicas using Execute()", "Attempted to execute computation with %d replicas using Execute()",
num_replicas()); num_replicas());
} }
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->backend().computation_placer()->AssignDevices(
1, /*computation_count=*/1));
StatusOr<ScopedShapedBuffer> result_buffer_status; StatusOr<ScopedShapedBuffer> result_buffer_status;
const int device_ordinal = device_assignment(0, 0); const int device_ordinal = device_assignment_(0, 0);
VLOG(3) << "Replica 0 mapped to device ordinal for execution: " VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
<< device_ordinal; << device_ordinal;
@ -360,10 +350,10 @@ StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
ExecutableRunOptions options; ExecutableRunOptions options;
options.set_device_ordinal(device_ordinal); options.set_device_ordinal(device_ordinal);
options.set_allocator(client->backend().memory_allocator()); options.set_allocator(client_->backend().memory_allocator());
options.set_intra_op_thread_pool( options.set_intra_op_thread_pool(
client->backend().eigen_intra_op_thread_pool_device()); client_->backend().eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment); options.set_device_assignment(&device_assignment_);
result_buffer_status = executable_->Run(argument_buffers, options); result_buffer_status = executable_->Run(argument_buffers, options);
@ -373,13 +363,13 @@ StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
"%s.", "%s.",
result_buffer_status.status().ToString()); result_buffer_status.status().ToString());
} }
return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
client_);
} }
StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica( StatusOr<LocalShapedBufferTuple*> LocalExecutable::ExecutePerReplica(
absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) { absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); const int num_devices = client_->device_count();
const int num_devices = client->device_count();
if (argument_handles.size() != num_replicas()) { if (argument_handles.size() != num_replicas()) {
return InvalidArgument( return InvalidArgument(
@ -394,14 +384,9 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(
VLOG(1) << "Executing with " << num_replicas() << " replicas."; VLOG(1) << "Executing with " << num_replicas() << " replicas.";
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->backend().computation_placer()->AssignDevices(
num_replicas(), /*computation_count=*/1));
std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas()); std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
auto execute = [this, client, &device_assignment, &argument_handles, auto execute = [this, &argument_handles, &results](int replica) {
&results](int replica) { const int device_ordinal = device_assignment_(replica, 0);
const int device_ordinal = device_assignment(replica, 0);
VLOG(3) << "Replica " << replica VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: " << device_ordinal; << " mapped to device ordinal for execution: " << device_ordinal;
@ -413,10 +398,10 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(
ExecutableRunOptions options; ExecutableRunOptions options;
options.set_device_ordinal(device_ordinal); options.set_device_ordinal(device_ordinal);
options.set_allocator(client->backend().memory_allocator()); options.set_allocator(client_->backend().memory_allocator());
options.set_intra_op_thread_pool( options.set_intra_op_thread_pool(
client->backend().eigen_intra_op_thread_pool_device()); client_->backend().eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment); options.set_device_assignment(&device_assignment_);
StatusOr<ScopedShapedBuffer> result_buffer_status = StatusOr<ScopedShapedBuffer> result_buffer_status =
executable_->Run(argument_buffers, options); executable_->Run(argument_buffers, options);
@ -448,26 +433,19 @@ StatusOr<LocalShapedBufferTuple*> CompiledLocalComputation::ExecutePerReplica(
replica, statusor.status().ToString()); replica, statusor.status().ToString());
} }
wrapped_results[replica] = wrapped_results[replica] =
new LocalShapedBuffer(std::move(statusor).ValueOrDie()); new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
} }
return new LocalShapedBufferTuple(std::move(wrapped_results)); return new LocalShapedBufferTuple(std::move(wrapped_results));
} }
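Because the DeviceAssignment is now fixed when the executable is built, ExecutePerReplica expects exactly num_replicas() argument vectors in replica order, and DeviceOrdinals() exposes the replica-to-device mapping that used to be recomputed per call. A hedged sketch of that calling convention; the helper name is illustrative and includes are omitted.

// Illustrative sketch: per-replica execution. Argument vectors are indexed by
// replica, and DeviceOrdinals()[replica] reports which device that replica
// was assigned to at compile time.
xla::StatusOr<xla::swig::LocalShapedBufferTuple*> RunOnAllReplicas(
    xla::swig::LocalExecutable* executable,
    const std::vector<std::vector<xla::swig::LocalShapedBuffer*>>&
        per_replica_args) {
  if (static_cast<int>(per_replica_args.size()) !=
      executable->num_replicas()) {
    return xla::InvalidArgument("Expected %d argument vectors, got %d.",
                                executable->num_replicas(),
                                static_cast<int>(per_replica_args.size()));
  }
  const std::vector<int> device_ordinals = executable->DeviceOrdinals();
  for (int replica = 0; replica < executable->num_replicas(); ++replica) {
    VLOG(1) << "Replica " << replica << " runs on device ordinal "
            << device_ordinals[replica];
  }
  // The returned tuple holds one result LocalShapedBuffer per replica.
  return executable->ExecutePerReplica(per_replica_args);
}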
static StatusOr<Shape> GetReturnValueShape(const XlaComputation& computation) { XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle,
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
return std::move(*program_shape.mutable_result());
}
CompiledXrtComputation::CompiledXrtComputation(
const ProgramShape& program_shape, int64 handle,
const string& session_target) const string& session_target)
: program_shape_(program_shape), : program_shape_(program_shape),
handle_(handle), handle_(handle),
session_target_(session_target) {} session_target_(session_target) {}
CompiledXrtComputation::~CompiledXrtComputation() { XrtExecutable::~XrtExecutable() {
tensorflow::Scope root = tensorflow::Scope::NewRootScope(); tensorflow::Scope root = tensorflow::Scope::NewRootScope();
auto computation_handle = auto computation_handle =
tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); tensorflow::ops::Placeholder(root, tensorflow::DT_INT64);
@ -489,7 +467,7 @@ CompiledXrtComputation::~CompiledXrtComputation() {
} }
} }
StatusOr<XrtAllocation*> CompiledXrtComputation::Execute( StatusOr<XrtAllocation*> XrtExecutable::Execute(
absl::Span<XrtAllocation* const> argument_handles) { absl::Span<XrtAllocation* const> argument_handles) {
const int num_expected_arguments = program_shape().parameters().size(); const int num_expected_arguments = program_shape().parameters().size();
@ -528,36 +506,41 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
return new XrtAllocation(output, program_shape().result(), session_target_); return new XrtAllocation(output, program_shape().result(), session_target_);
} }
const ProgramShape& CompiledXrtComputation::program_shape() const { const ProgramShape& XrtExecutable::program_shape() const {
return program_shape_; return program_shape_;
} }
int64 CompiledXrtComputation::handle() const { return handle_; } int64 XrtExecutable::handle() const { return handle_; }
LocalComputation::LocalComputation(XlaComputation computation) Computation::Computation(XlaComputation computation)
: computation_(std::move(computation)) {} : computation_(std::move(computation)) {}
StatusOr<CompiledLocalComputation*> LocalComputation::Compile( StatusOr<LocalExecutable*> Computation::Compile(
const std::vector<Shape>& argument_shapes, const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options) { const ExecutableBuildOptions* build_options, const LocalClient& client) {
std::vector<const Shape*> argument_shape_pointers; std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size()); argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) { for (auto& argument_shape : argument_shapes) {
argument_shape_pointers.push_back(&argument_shape); argument_shape_pointers.push_back(&argument_shape);
} }
TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());
ExecutableBuildOptions options; ExecutableBuildOptions options;
if (build_options != nullptr) { if (build_options != nullptr) {
options = *build_options; options = *build_options;
} }
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto local_executable, auto local_executable,
client->Compile(computation_, argument_shape_pointers, options)); client.client()->Compile(computation_, argument_shape_pointers, options));
return new CompiledLocalComputation(std::move(local_executable)); TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
client.client()->backend().computation_placer()->AssignDevices(
options.num_replicas(), /*computation_count=*/1));
return new LocalExecutable(std::move(local_executable),
std::move(device_assignment), client.client());
} }
StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt( StatusOr<XrtExecutable*> Computation::CompileForXrt(
const std::vector<Shape>& argument_shapes, const string& session_target) { const std::vector<Shape>& argument_shapes, const string& session_target) {
tensorflow::Scope root = tensorflow::Scope::NewRootScope(); tensorflow::Scope root = tensorflow::Scope::NewRootScope();
auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
@ -585,14 +568,12 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation().GetProgramShape()); computation().GetProgramShape());
int64 handle = outputs[0].scalar<int64>()(); int64 handle = outputs[0].scalar<int64>()();
return new CompiledXrtComputation(program_shape, handle, session_target); return new XrtExecutable(program_shape, handle, session_target);
} }
const XlaComputation& LocalComputation::computation() const { const XlaComputation& Computation::computation() const { return computation_; }
return computation_;
}
string LocalComputation::GetSerializedProto() const { string Computation::GetSerializedProto() const {
string result; string result;
if (!computation_.proto().SerializeToString(&result)) { if (!computation_.proto().SerializeToString(&result)) {
LOG(ERROR) << "Failed to serialize the HloModuleProto."; LOG(ERROR) << "Failed to serialize the HloModuleProto.";
@ -601,101 +582,103 @@ string LocalComputation::GetSerializedProto() const {
return result; return result;
} }
StatusOr<Shape> LocalComputation::GetReturnValueShape() const { StatusOr<ProgramShape> Computation::GetProgramShape() const {
return swig::GetReturnValueShape(computation_); return computation_.GetProgramShape();
}
StatusOr<Shape> Computation::GetReturnValueShape() const {
TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape());
return std::move(*shape.mutable_result());
} }
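A small, hedged sketch of the two shape accessors added here: GetProgramShape returns the whole program signature, while GetReturnValueShape remains a convenience for the result only. The helper name is illustrative.

// Illustrative sketch: logs the parameter count and result shape of a
// built Computation.
xla::Status DescribeComputation(const xla::swig::Computation& computation) {
  // GetProgramShape returns parameter shapes and the result shape together...
  TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
                      computation.GetProgramShape());
  VLOG(1) << "Parameter count: " << program_shape.parameters().size();
  // ...while GetReturnValueShape extracts only the result shape.
  TF_ASSIGN_OR_RETURN(xla::Shape result_shape,
                      computation.GetReturnValueShape());
  VLOG(1) << "Result shape: " << xla::ShapeUtil::HumanString(result_shape);
  return xla::Status::OK();
}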
LocalOp::LocalOp(const XlaOp& op) : op_(op) {} LocalOp::LocalOp(const XlaOp& op) : op_(op) {}
const XlaOp& LocalOp::op() const { return op_; } const XlaOp& LocalOp::op() const { return op_; }
LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) ComputationBuilder::ComputationBuilder(const string& computation_name)
: builder_(computation_name) {} : builder_(computation_name) {}
void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
builder_.SetOpMetadata(metadata); builder_.SetOpMetadata(metadata);
} }
void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }
StatusOr<LocalComputation*> LocalComputationBuilder::Build() { StatusOr<Computation*> ComputationBuilder::Build() {
TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build());
return new LocalComputation(std::move(computation)); return new Computation(std::move(computation));
} }
LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, LocalOp ComputationBuilder::Parameter(int64 parameter_number,
const Shape& shape, const Shape& shape, const string& name) {
const string& name) {
return xla::Parameter(&builder_, parameter_number, shape, name); return xla::Parameter(&builder_, parameter_number, shape, name);
} }
StatusOr<LocalComputation*> LocalComputationBuilder::BuildWithRoot( StatusOr<Computation*> ComputationBuilder::BuildWithRoot(const LocalOp& root) {
const LocalOp& root) {
TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op()));
return new LocalComputation(std::move(computation)); return new Computation(std::move(computation));
} }
StatusOr<Shape> LocalComputationBuilder::GetShape(const LocalOp& operand) { StatusOr<Shape> ComputationBuilder::GetShape(const LocalOp& operand) {
return builder_.GetShape(operand.op()); return builder_.GetShape(operand.op());
} }
StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() { StatusOr<Shape> ComputationBuilder::GetReturnValueShape() {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
return program_shape.result(); return program_shape.result();
} }
LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { LocalOp ComputationBuilder::Infeed(const Shape& shape) {
return xla::Infeed(&builder_, shape); return xla::Infeed(&builder_, shape);
} }
void LocalComputationBuilder::Outfeed(const LocalOp& operand, void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape,
const Shape& shape,
const string& outfeed_config) { const string& outfeed_config) {
xla::Outfeed(operand.op(), shape, outfeed_config); xla::Outfeed(operand.op(), shape, outfeed_config);
} }
LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) {
return xla::ConstantLiteral(&builder_, literal); return xla::ConstantLiteral(&builder_, literal);
} }
LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
return xla::Iota(&builder_, element_type, size); return xla::Iota(&builder_, element_type, size);
} }
LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape,
int64 dimension) { int64 dimension) {
return xla::Iota(&builder_, shape, dimension); return xla::Iota(&builder_, shape, dimension);
} }
LocalOp LocalComputationBuilder::Broadcast( LocalOp ComputationBuilder::Broadcast(const LocalOp& operand,
const LocalOp& operand, absl::Span<const int64> broadcast_sizes) { absl::Span<const int64> broadcast_sizes) {
return xla::Broadcast(operand.op(), broadcast_sizes); return xla::Broadcast(operand.op(), broadcast_sizes);
} }
LocalOp LocalComputationBuilder::BroadcastInDim( LocalOp ComputationBuilder::BroadcastInDim(
const LocalOp& operand, absl::Span<const int64> out_dim_sizes, const LocalOp& operand, absl::Span<const int64> out_dim_sizes,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions);
} }
LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, LocalOp ComputationBuilder::Pad(const LocalOp& operand,
const LocalOp& padding_value, const LocalOp& padding_value,
const PaddingConfig& padding_config) { const PaddingConfig& padding_config) {
return xla::Pad(operand.op(), padding_value.op(), padding_config); return xla::Pad(operand.op(), padding_value.op(), padding_config);
} }
LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, LocalOp ComputationBuilder::Reshape(const LocalOp& operand,
absl::Span<const int64> dimensions, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes) { absl::Span<const int64> new_sizes) {
return xla::Reshape(operand.op(), dimensions, new_sizes); return xla::Reshape(operand.op(), dimensions, new_sizes);
} }
LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, LocalOp ComputationBuilder::Collapse(const LocalOp& operand,
absl::Span<const int64> dimensions) { absl::Span<const int64> dimensions) {
return xla::Collapse(operand.op(), dimensions); return xla::Collapse(operand.op(), dimensions);
} }
LocalOp LocalComputationBuilder::AllToAll( LocalOp ComputationBuilder::AllToAll(
const LocalOp& operand, int64 split_dimension, int64 concat_dimension, const LocalOp& operand, int64 split_dimension, int64 concat_dimension,
int64 split_count, absl::Span<const ReplicaGroup> replica_groups) { int64 split_count, absl::Span<const ReplicaGroup> replica_groups) {
std::vector<ReplicaGroup> rg(replica_groups.size()); std::vector<ReplicaGroup> rg(replica_groups.size());
@ -706,38 +689,37 @@ LocalOp LocalComputationBuilder::AllToAll(
split_count, rg); split_count, rg);
} }
LocalOp LocalComputationBuilder::CrossReplicaSum( LocalOp ComputationBuilder::CrossReplicaSum(
const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) { const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
return xla::CrossReplicaSum(operand.op(), replica_groups); return xla::CrossReplicaSum(operand.op(), replica_groups);
} }
LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, LocalOp ComputationBuilder::Slice(const LocalOp& operand,
absl::Span<const int64> start_indices, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) { absl::Span<const int64> strides) {
return xla::Slice(operand.op(), start_indices, limit_indices, strides); return xla::Slice(operand.op(), start_indices, limit_indices, strides);
} }
LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand,
int64 start_index, int64 start_index, int64 limit_index,
int64 limit_index, int64 stride, int64 stride, int64 dimno) {
int64 dimno) {
return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno);
} }
LocalOp LocalComputationBuilder::DynamicSlice( LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand,
const LocalOp& operand, const LocalOp& start_indices, const LocalOp& start_indices,
absl::Span<const int64> slice_sizes) { absl::Span<const int64> slice_sizes) {
return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
} }
LocalOp LocalComputationBuilder::DynamicUpdateSlice( LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand,
const LocalOp& operand, const LocalOp& update, const LocalOp& update,
const LocalOp& start_indices) { const LocalOp& start_indices) {
return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
} }
LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands, LocalOp ComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
int64 dimension) { int64 dimension) {
std::vector<XlaOp> xla_ops; std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size()); xla_ops.reserve(operands.size());
@ -747,18 +729,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
return xla::ConcatInDim(&builder_, xla_ops, dimension); return xla::ConcatInDim(&builder_, xla_ops, dimension);
} }
LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select, const LocalOp& operand, const Computation& select,
absl::Span<const int64> window_dimensions, absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source, absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
const LocalOp& init_value, const LocalComputation& scatter) { const LocalOp& init_value, const Computation& scatter) {
return xla::SelectAndScatterWithGeneralPadding( return xla::SelectAndScatterWithGeneralPadding(
operand.op(), select.computation(), window_dimensions, window_strides, operand.op(), select.computation(), window_dimensions, window_strides,
padding, source.op(), init_value.op(), scatter.computation()); padding, source.op(), init_value.op(), scatter.computation());
} }
LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) { LocalOp ComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
std::vector<XlaOp> xla_ops; std::vector<XlaOp> xla_ops;
xla_ops.reserve(elements.size()); xla_ops.reserve(elements.size());
for (const auto& op : elements) { for (const auto& op : elements) {
@ -768,22 +750,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
return xla::Tuple(&builder_, xla_ops); return xla::Tuple(&builder_, xla_ops);
} }
LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
int64 index) { int64 index) {
return xla::GetTupleElement(tuple_data.op(), index); return xla::GetTupleElement(tuple_data.op(), index);
} }
LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
return xla::Dot(lhs.op(), rhs.op()); return xla::Dot(lhs.op(), rhs.op());
} }
LocalOp LocalComputationBuilder::DotGeneral( LocalOp ComputationBuilder::DotGeneral(
const LocalOp& lhs, const LocalOp& rhs, const LocalOp& lhs, const LocalOp& rhs,
const DotDimensionNumbers& dimension_numbers) { const DotDimensionNumbers& dimension_numbers) {
return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
} }
LocalOp LocalComputationBuilder::ConvGeneralDilated( LocalOp ComputationBuilder::ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs, const LocalOp& lhs, const LocalOp& rhs,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, absl::Span<const std::pair<int64, int64>> padding,
@ -795,17 +777,17 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
feature_group_count); feature_group_count);
} }
LocalOp LocalComputationBuilder::ConvertElementType( LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand,
const LocalOp& operand, PrimitiveType new_element_type) { PrimitiveType new_element_type) {
return xla::ConvertElementType(operand.op(), new_element_type); return xla::ConvertElementType(operand.op(), new_element_type);
} }
LocalOp LocalComputationBuilder::BitcastConvertType( LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand,
const LocalOp& operand, PrimitiveType new_element_type) { PrimitiveType new_element_type) {
return xla::BitcastConvertType(operand.op(), new_element_type); return xla::BitcastConvertType(operand.op(), new_element_type);
} }
LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, LocalOp ComputationBuilder::Call(const Computation& local_computation,
absl::Span<const LocalOp> operands) { absl::Span<const LocalOp> operands) {
std::vector<XlaOp> xla_ops; std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size()); xla_ops.reserve(operands.size());
@ -815,7 +797,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
return xla::Call(&builder_, local_computation.computation(), xla_ops); return xla::Call(&builder_, local_computation.computation(), xla_ops);
} }
LocalOp LocalComputationBuilder::CustomCall( LocalOp ComputationBuilder::CustomCall(
const string& call_target_name, absl::Span<const LocalOp> operands, const string& call_target_name, absl::Span<const LocalOp> operands,
const Shape& shape_with_layout, const Shape& shape_with_layout,
const std::vector<Shape>& operand_shapes_with_layout, const std::vector<Shape>& operand_shapes_with_layout,
@ -830,18 +812,18 @@ LocalOp LocalComputationBuilder::CustomCall(
operand_shapes_with_layout, opaque); operand_shapes_with_layout, opaque);
} }
LocalOp LocalComputationBuilder::Transpose( LocalOp ComputationBuilder::Transpose(const LocalOp& operand,
const LocalOp& operand, absl::Span<const int64> permutation) { absl::Span<const int64> permutation) {
return xla::Transpose(operand.op(), permutation); return xla::Transpose(operand.op(), permutation);
} }
LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, LocalOp ComputationBuilder::Rev(const LocalOp& operand,
absl::Span<const int64> dimensions) { absl::Span<const int64> dimensions) {
return xla::Rev(operand.op(), dimensions); return xla::Rev(operand.op(), dimensions);
} }
LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands, LocalOp ComputationBuilder::Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> dimensions) { absl::Span<const int64> dimensions) {
std::vector<XlaOp> xla_ops; std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size()); xla_ops.reserve(operands.size());
@ -853,17 +835,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
dimensions); dimensions);
} }
LocalOp LocalComputationBuilder::Reduce( LocalOp ComputationBuilder::Reduce(
const LocalOp& operand, const LocalOp& init_value, const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> dimensions_to_reduce) { absl::Span<const int64> dimensions_to_reduce) {
return xla::Reduce(operand.op(), init_value.op(), return xla::Reduce(operand.op(), init_value.op(),
local_computation.computation(), dimensions_to_reduce); local_computation.computation(), dimensions_to_reduce);
} }
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value, const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> window_dimensions, absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations, absl::Span<const int64> base_dilations,
@ -875,51 +857,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
padding); padding);
} }
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma,
const LocalOp& sigma,
const Shape& shape) { const Shape& shape) {
return xla::RngNormal(mu.op(), sigma.op(), shape); return xla::RngNormal(mu.op(), sigma.op(), shape);
} }
LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
const Shape& shape) { const Shape& shape) {
return xla::RngUniform(a.op(), b.op(), shape); return xla::RngUniform(a.op(), b.op(), shape);
} }
LocalOp LocalComputationBuilder::While(const LocalComputation& condition, LocalOp ComputationBuilder::While(const Computation& condition,
const LocalComputation& body, const Computation& body,
const LocalOp& init) { const LocalOp& init) {
return xla::While(condition.computation(), body.computation(), init.op()); return xla::While(condition.computation(), body.computation(), init.op());
} }
LocalOp LocalComputationBuilder::Conditional( LocalOp ComputationBuilder::Conditional(const LocalOp& predicate,
const LocalOp& predicate, const LocalOp& true_operand, const LocalOp& true_operand,
const LocalComputation& true_computation, const LocalOp& false_operand, const Computation& true_computation,
const LocalComputation& false_computation) { const LocalOp& false_operand,
const Computation& false_computation) {
return xla::Conditional(predicate.op(), true_operand.op(), return xla::Conditional(predicate.op(), true_operand.op(),
true_computation.computation(), false_operand.op(), true_computation.computation(), false_operand.op(),
false_computation.computation()); false_computation.computation());
} }
StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) { StatusOr<bool> ComputationBuilder::IsConstant(const LocalOp& operand) {
return builder_.IsConstant(operand.op()); return builder_.IsConstant(operand.op());
} }
LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
return xla::Sort(operand.op(), {}, dimension); return xla::Sort(operand.op(), {}, dimension);
} }
LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
const LocalOp& values, const LocalOp& values, int64 dimension) {
int64 dimension) {
return xla::Sort(keys.op(), {values.op()}, dimension); return xla::Sort(keys.op(), {values.op()}, dimension);
} }
LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { LocalOp ComputationBuilder::Cholesky(const LocalOp& a) {
return xla::Cholesky(a.op()); return xla::Cholesky(a.op());
} }
LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
XlaBuilder* builder = a.op().builder(); XlaBuilder* builder = a.op().builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices));
@ -927,8 +908,7 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
}); });
} }
LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b,
const LocalOp& b,
bool left_side, bool lower, bool left_side, bool lower,
bool unit_diagonal, bool unit_diagonal,
int transpose_a) { int transpose_a) {
@ -937,7 +917,7 @@ LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a,
xla::TriangularSolveOptions::Transpose(transpose_a)); xla::TriangularSolveOptions::Transpose(transpose_a));
} }
LocalOp LocalComputationBuilder::Gather( LocalOp ComputationBuilder::Gather(
const LocalOp& input, const LocalOp& start_indices, const LocalOp& input, const LocalOp& start_indices,
const GatherDimensionNumbers& dimension_numbers, const GatherDimensionNumbers& dimension_numbers,
absl::Span<const int64> slice_sizes) { absl::Span<const int64> slice_sizes) {
@ -945,23 +925,23 @@ LocalOp LocalComputationBuilder::Gather(
slice_sizes); slice_sizes);
} }
LocalOp LocalComputationBuilder::Scatter( LocalOp ComputationBuilder::Scatter(
const LocalOp& input, const LocalOp& scatter_indices, const LocalOp& input, const LocalOp& scatter_indices,
const LocalOp& updates, const LocalComputation& update_computation, const LocalOp& updates, const Computation& update_computation,
const ScatterDimensionNumbers& dimension_numbers) { const ScatterDimensionNumbers& dimension_numbers) {
return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), return xla::Scatter(input.op(), scatter_indices.op(), updates.op(),
update_computation.computation(), dimension_numbers); update_computation.computation(), dimension_numbers);
} }
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph( StatusOr<Computation*> ComputationBuilder::BuildConstantSubGraph(
const LocalOp& operand) { const LocalOp& operand) {
TF_ASSIGN_OR_RETURN(XlaComputation computation, TF_ASSIGN_OR_RETURN(XlaComputation computation,
builder_.BuildConstantSubGraph(operand.op())); builder_.BuildConstantSubGraph(operand.op()));
return new LocalComputation(std::move(computation)); return new Computation(std::move(computation));
} }
#define _FORWARD(method_name, return_sig, args_sig, args) \ #define _FORWARD(method_name, return_sig, args_sig, args) \
return_sig LocalComputationBuilder::method_name args_sig { \ return_sig ComputationBuilder::method_name args_sig { \
return xla::method_name args; \ return xla::method_name args; \
} }
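The renamed ComputationBuilder methods above all delegate to the free functions in the xla namespace. A minimal, hedged sketch of building a Computation with it; the shapes, names, and the choice of Dot are arbitrary and not taken from the change.

// Illustrative sketch: builds a computation that multiplies two f32 matrices.
xla::StatusOr<xla::swig::Computation*> BuildMatmulComputation() {
  xla::swig::ComputationBuilder builder("matmul");
  const xla::Shape lhs_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  const xla::Shape rhs_shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4});
  xla::swig::LocalOp lhs = builder.Parameter(0, lhs_shape, "lhs");
  xla::swig::LocalOp rhs = builder.Parameter(1, rhs_shape, "rhs");
  builder.Dot(lhs, rhs);  // The most recently added op becomes the root.
  // Build() returns an owned Computation*; Python frees it via
  // DeleteComputation.
  return builder.Build();
}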
@ -1051,64 +1031,11 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; }
void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; }
delete computation;
}
void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; }
delete computation;
}
void DeleteLocalComputation(LocalComputation* computation) { void DeleteComputation(Computation* computation) { delete computation; }
delete computation;
}
StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
LocalShapedBuffer* local_shaped_buffer) {
const Shape tuple_shape = local_shaped_buffer->shape();
if (!tuple_shape.IsTuple()) {
return InvalidArgument(
"Attemped to destructure a LocalShapedBuffer that did not have a tuple "
"shape; shape: %s",
ShapeUtil::HumanString(tuple_shape));
}
DeviceMemoryAllocator* allocator =
local_shaped_buffer->shaped_buffer()->memory_allocator();
ShapedBuffer tuple_buffer = local_shaped_buffer->Release();
// Extract some metadata we use to construct scoped buffers.
const se::Platform* platform = tuple_buffer.platform();
int device_ordinal = tuple_buffer.device_ordinal();
ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
std::vector<LocalShapedBuffer*> results;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
// Create a shaped buffer for this destructured tuple element.
const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);
ShapeUtil::ForEachSubshape(
subshape, [&](const Shape& s, const ShapeIndex& index) {
ShapeIndex original(index);
original.push_front(i);
se::DeviceMemoryBase* device_memory =
shape_tree.mutable_element(original);
shaped_buffer.set_buffer(*device_memory, index);
*device_memory = se::DeviceMemoryBase();
});
VLOG(3) << "Completed tuple element: " << i;
results.push_back(new LocalShapedBuffer(
ScopedShapedBuffer(std::move(shaped_buffer), allocator)));
}
// Deallocate the root buffer.
se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
return new LocalShapedBufferTuple(std::move(results));
}
StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple( StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
XrtAllocation* allocation, const string& session_target) { XrtAllocation* allocation, const string& session_target) {


@ -35,42 +35,42 @@ limitations under the License.
namespace xla { namespace xla {
namespace swig { namespace swig {
// Initializes the number of replicas that XLA will be initialized with (when
// first obtaining a handle to the local XLA service). If this is called after
// the handle to the local XLA service has been established, then an error is
// returned.
Status InitializeReplicaCount(int replica_count);
// Initializes the platform name that XLA will be initialized with (when
// first obtaining a handle to the local XLA service). If this is called after
// the handle to the local XLA service has been established, then an error is
// returned.
Status InitializePlatformName(const string& platform_name);
// Returns the replica count that is currently set, regardless of whether the
// local XLA service has been instantiated yet or not.
int GetReplicaCount();
// Registers a 'fn_capsule' as a CPU custom call target. // Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
// "xla._CPU_CUSTOM_CALL_TARGET". // "xla._CPU_CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule);
// Wraps the local client's infeed-transfer function. // Wrapper around an xla::LocalClient.
// class LocalClient {
// The default device ordinal (0) is used. public:
Status TransferToInfeedLocal(const Literal& literal); // Initializes a local XLA client for `platform_name`. Returns an error if no
/// such platform exists, or if the platform has no visible devices.
static StatusOr<LocalClient> Get(const string& platform_name);
// Transfers the given literal to the infeed of the given replica. // Copyable and moveable; the class is just a wrapper around a
// // xla::LocalClient pointer for convenient SWIG wrapping.
// The replica number is resolved to an appropriate device ordinal.
Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica. // Returns the number of devices known to the XLA client.
// int DeviceCount() const;
// The replica number is resolved to an appropriate device ordinal.
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape, // Wraps the local client's infeed-transfer function.
int replica_number); //
// The default device ordinal (0) is used.
Status TransferToInfeed(const Literal& literal, int device_ordinal);
// Transfers a literal of the given shape from the outfeed of the given
// device.
StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
xla::LocalClient* client() const { return client_; }
private:
LocalClient(xla::LocalClient* client);
xla::LocalClient* client_;
};
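A hedged sketch of the class declared above: infeed and outfeed now go through a LocalClient instance and an explicit device ordinal instead of the old replica-number free functions. The platform name "Host" and the helper name are assumptions; the diff does not enumerate platform names, and includes are omitted.

// Illustrative sketch: push a literal to the infeed of device 0 and read a
// literal of the same shape back from its outfeed.
xla::Status InfeedOutfeedExample(const xla::Literal& literal) {
  TF_ASSIGN_OR_RETURN(xla::swig::LocalClient client,
                      xla::swig::LocalClient::Get("Host"));
  VLOG(1) << "Client sees " << client.DeviceCount() << " device(s).";
  TF_RETURN_IF_ERROR(client.TransferToInfeed(literal, /*device_ordinal=*/0));
  // A computation running on device 0 would normally consume the infeed and
  // feed the outfeed between these two calls; that part is elided.
  TF_ASSIGN_OR_RETURN(xla::Literal result,
                      client.TransferFromOutfeed(literal.shape(),
                                                 /*device_ordinal=*/0));
  VLOG(1) << "Outfed shape: " << xla::ShapeUtil::HumanString(result.shape());
  return xla::Status::OK();
}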
class LocalShapedBufferTuple;
// Represents a reference to literals that live in a device-allocated buffer via // Represents a reference to literals that live in a device-allocated buffer via
// XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
@ -79,9 +79,9 @@ class LocalShapedBuffer {
public: public:
static StatusOr<LocalShapedBuffer*> FromLiteral( static StatusOr<LocalShapedBuffer*> FromLiteral(
const Literal& argument, const absl::optional<Shape>& shape_with_layout, const Literal& argument, const absl::optional<Shape>& shape_with_layout,
int replica_number); const LocalClient& client, int device_ordinal);
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client);
StatusOr<Literal> ToLiteral() const; StatusOr<Literal> ToLiteral() const;
const Shape& shape() const; const Shape& shape() const;
const ScopedShapedBuffer* shaped_buffer() const; const ScopedShapedBuffer* shaped_buffer() const;
@ -90,8 +90,13 @@ class LocalShapedBuffer {
// analogous to std::unique_ptr::release(). // analogous to std::unique_ptr::release().
ShapedBuffer Release(); ShapedBuffer Release();
// Destructures a tuple-valued LocalShapedBuffer into its constituent
// elements in LocalShapedBufferTuple form.
StatusOr<LocalShapedBufferTuple*> DestructureTuple();
private: private:
ScopedShapedBuffer shaped_buffer_; ScopedShapedBuffer shaped_buffer_;
xla::LocalClient* client_;
}; };
// Result of a tuple destructuring operation on a LocalShapedBuffer -- this // Result of a tuple destructuring operation on a LocalShapedBuffer -- this
@ -117,11 +122,6 @@ class LocalShapedBufferTuple {
std::vector<LocalShapedBuffer*> elements_; std::vector<LocalShapedBuffer*> elements_;
}; };
// Destructures a tuple-valued LocalShapedBuffer into its constituent elements
// in LocalShapedBufferTuple form.
StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
LocalShapedBuffer* local_shaped_buffer);
// Represents a reference to literals that live in a device-allocated buffer via // Represents a reference to literals that live in a device-allocated buffer via
// XRT. Specifically, wraps an int64 handle produced by running the allocation // XRT. Specifically, wraps an int64 handle produced by running the allocation
// graph, and an XLA shape to track the referent's shape. // graph, and an XLA shape to track the referent's shape.
@ -176,14 +176,19 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
// Represents a compiled computation that can be executed given handles to // Represents a compiled computation that can be executed given handles to
// device-allocated literals. Specifically, wraps an XLA LocalExecutable. // device-allocated literals. Specifically, wraps an XLA LocalExecutable.
class CompiledLocalComputation { class LocalExecutable {
public: public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable); LocalExecutable(std::unique_ptr<xla::LocalExecutable> executable,
xla::DeviceAssignment device_assignment,
xla::LocalClient* client);
int num_replicas() const { int num_replicas() const {
return executable_->build_options().num_replicas(); return executable_->build_options().num_replicas();
} }
// Returns the device ordinals to which each replica is assigned.
std::vector<int> DeviceOrdinals() const;
StatusOr<LocalShapedBuffer*> Execute( StatusOr<LocalShapedBuffer*> Execute(
absl::Span<LocalShapedBuffer* const> argument_handles); absl::Span<LocalShapedBuffer* const> argument_handles);
@ -194,18 +199,22 @@ class CompiledLocalComputation {
absl::Span<const std::vector<LocalShapedBuffer*> > argument_handles); absl::Span<const std::vector<LocalShapedBuffer*> > argument_handles);
private: private:
std::unique_ptr<LocalExecutable> executable_; const std::unique_ptr<xla::LocalExecutable> executable_;
const xla::DeviceAssignment device_assignment_;
xla::LocalClient* const client_;
}; };
// Represents a compiled computation that can be executed given handles to // Represents a compiled computation that can be executed given handles to
// device-allocated literals. Specifically, wraps an XRT computation handle. // device-allocated literals. Specifically, wraps an XRT computation handle.
class CompiledXrtComputation { class XrtExecutable {
public: public:
// Accepts a `session_target` argument, used in constructing the // Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the execution graph is run. // `tensorflow::ClientSession` instance in which the execution graph is run.
CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, XrtExecutable(const ProgramShape& program_shape, int64 handle,
const string& session_target); const string& session_target);
~CompiledXrtComputation(); ~XrtExecutable();
std::vector<int> DeviceOrdinals() const { return {0}; }
StatusOr<XrtAllocation*> Execute( StatusOr<XrtAllocation*> Execute(
absl::Span<XrtAllocation* const> argument_handles); absl::Span<XrtAllocation* const> argument_handles);
@ -219,21 +228,21 @@ class CompiledXrtComputation {
const string session_target_; const string session_target_;
}; };
// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Wraps a XlaComputation produced by a ComputationBuilder. The
// Compile method compiles the computation to a (local) executable via // Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be // the client library's local client. This class is intended to be
// made available to Python via SWIG. // made available to Python via SWIG.
class LocalComputation { class Computation {
public: public:
LocalComputation(XlaComputation computation); Computation(XlaComputation computation);
StatusOr<CompiledLocalComputation*> Compile( StatusOr<LocalExecutable*> Compile(
const std::vector<Shape>& argument_shapes, const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options); const ExecutableBuildOptions* build_options, const LocalClient& client);
// Accepts a `session_target` argument, used in constructing the // Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the compilation graph is run. // `tensorflow::ClientSession` instance in which the compilation graph is run.
StatusOr<CompiledXrtComputation*> CompileForXrt( StatusOr<XrtExecutable*> CompileForXrt(
const std::vector<Shape>& argument_shapes, const string& session_target); const std::vector<Shape>& argument_shapes, const string& session_target);
const XlaComputation& computation() const; const XlaComputation& computation() const;
@ -243,6 +252,9 @@ class LocalComputation {
// string on failure. // string on failure.
string GetSerializedProto() const; string GetSerializedProto() const;
// Returns the program shape for this computation.
StatusOr<ProgramShape> GetProgramShape() const;
// Returns the return-value shape for this computation. // Returns the return-value shape for this computation.
StatusOr<Shape> GetReturnValueShape() const; StatusOr<Shape> GetReturnValueShape() const;
@ -250,7 +262,7 @@ class LocalComputation {
XlaComputation computation_; XlaComputation computation_;
}; };
// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended // Wraps a XlaOp produced by a ComputationBuilder. This class is intended
// to be made available to Python via SWIG. // to be made available to Python via SWIG.
class LocalOp { class LocalOp {
public: public:
@ -267,20 +279,20 @@ class LocalOp {
// Python. // Python.
// - Set up the underlying builder to use the client library's // - Set up the underlying builder to use the client library's
// LocalClient. // LocalClient.
// - Wrap Computations in LocalComputations for Python access. // - Wrap XlaComputations in Computations for Python access.
// - Correspondingly unwrap incoming LocalComputations. // - Correspondingly unwrap incoming Computations.
class LocalComputationBuilder { class ComputationBuilder {
public: public:
LocalComputationBuilder(const string& computation_name); ComputationBuilder(const string& computation_name);
void SetOpMetadata(const OpMetadata& metadata); void SetOpMetadata(const OpMetadata& metadata);
void ClearOpMetadata(); void ClearOpMetadata();
// Returns an owned LocalComputation to the caller on success. // Returns an owned Computation to the caller on success.
StatusOr<LocalComputation*> Build(); StatusOr<Computation*> Build();
// Returns an owned LocalComputation to the caller on success with given root. // Returns an owned Computation to the caller on success with given root.
StatusOr<LocalComputation*> BuildWithRoot(const LocalOp& root); StatusOr<Computation*> BuildWithRoot(const LocalOp& root);
LocalOp Parameter(int64 parameter_number, const Shape& shape, LocalOp Parameter(int64 parameter_number, const Shape& shape,
const string& name); const string& name);
@ -339,11 +351,11 @@ class LocalComputationBuilder {
LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension); LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);
LocalOp SelectAndScatterWithGeneralPadding( LocalOp SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select, const LocalOp& operand, const Computation& select,
absl::Span<const int64> window_dimensions, absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source, absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
const LocalOp& init_value, const LocalComputation& scatter); const LocalOp& init_value, const Computation& scatter);
LocalOp Tuple(absl::Span<const LocalOp> elements); LocalOp Tuple(absl::Span<const LocalOp> elements);
@ -369,7 +381,7 @@ class LocalComputationBuilder {
LocalOp BitcastConvertType(const LocalOp& operand, LocalOp BitcastConvertType(const LocalOp& operand,
PrimitiveType new_element_type); PrimitiveType new_element_type);
LocalOp Call(const LocalComputation& local_computation, LocalOp Call(const Computation& local_computation,
absl::Span<const LocalOp> operands); absl::Span<const LocalOp> operands);
LocalOp CustomCall(const string& call_target_name, LocalOp CustomCall(const string& call_target_name,
@ -384,16 +396,16 @@ class LocalComputationBuilder {
LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions); LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);
LocalOp Map(absl::Span<const LocalOp> operands, LocalOp Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> dimensions); absl::Span<const int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> dimensions_to_reduce); absl::Span<const int64> dimensions_to_reduce);
LocalOp ReduceWindowWithGeneralPadding( LocalOp ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value, const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation, const Computation& local_computation,
absl::Span<const int64> window_dimensions, absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations, absl::Span<const int64> base_dilations,
@ -405,13 +417,13 @@ class LocalComputationBuilder {
LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape);
LocalOp While(const LocalComputation& condition, const LocalComputation& body, LocalOp While(const Computation& condition, const Computation& body,
const LocalOp& init); const LocalOp& init);
LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand,
const LocalComputation& true_computation, const Computation& true_computation,
const LocalOp& false_operand, const LocalOp& false_operand,
const LocalComputation& false_computation); const Computation& false_computation);
StatusOr<bool> IsConstant(const LocalOp& operand); StatusOr<bool> IsConstant(const LocalOp& operand);
@ -435,11 +447,10 @@ class LocalComputationBuilder {
absl::Span<const int64> slice_sizes); absl::Span<const int64> slice_sizes);
LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices,
const LocalOp& updates, const LocalOp& updates, const Computation& update_computation,
const LocalComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers); const ScatterDimensionNumbers& dimension_numbers);
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand); StatusOr<Computation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \ #define _FORWARD(method_name, return_sig, args_sig) \
return_sig method_name args_sig; return_sig method_name args_sig;
@ -529,9 +540,9 @@ class LocalComputationBuilder {
// Functions for freeing resources from the Python side. // Functions for freeing resources from the Python side.
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer);
void DeleteXrtAllocation(XrtAllocation* allocation); void DeleteXrtAllocation(XrtAllocation* allocation);
void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalExecutable(LocalExecutable* computation);
void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); void DeleteXrtExecutable(XrtExecutable* computation);
void DeleteLocalComputation(LocalComputation* computation); void DeleteComputation(Computation* computation);
} // namespace swig } // namespace swig
} // namespace xla } // namespace xla

View File

@ -23,11 +23,13 @@ limitations under the License.
// C++ Python // C++ Python
// -------------------------------------+--------------------------------------- // -------------------------------------+---------------------------------------
// Span<int64> <- sequence of int // Span<int64> <- sequence of int
// vector<int> -> sequence of int
// Span<LocalOp> <- sequence of LocalOp // Span<LocalOp> <- sequence of LocalOp
// Literal <-> (nested tuple of) numpy ndarray // Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray // std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape -> pair holding (dtype, dimensions) // Shape -> pair holding (dtype, dimensions)
// <- object duck-typed as xla_client.Shape // <- object duck-typed as xla_client.Shape
// ProgramShape -> pair of ([arg_shapes], ret_shape)
// std::vector<Shape> <- sequence of xla_client.Shape objects // std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int // PrimitiveType <- int
// Span<pair<int64, int64>> <- sequence of int pairs // Span<pair<int64, int64>> <- sequence of int pairs
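As a hedged illustration of the Shape and ProgramShape conventions in the table above (names follow xla_client; exact dtypes depend on the build), the raw values that cross the SWIG boundary are plain Python pairs, which xla_client then wraps:

import numpy as np
# What the C++ side hands back for an f32[2,3] array shape:
raw_shape_info = (np.dtype('float32'), (2, 3))
dtype, dims = raw_shape_info            # exactly what xla_client._wrap_shape unpacks
# A ProgramShape arrives as ([arg_shape_infos], result_shape_info), e.g.:
raw_program_shape = ((raw_shape_info,), raw_shape_info)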
@ -97,7 +99,7 @@ limitations under the License.
// wrapped in a Python class (xla_client.Shape) so as not to expose // wrapped in a Python class (xla_client.Shape) so as not to expose
// the raw pair externally. // the raw pair externally.
// //
// Other SWIG object wrappers (e.g. of LocalComputation) are further // Other SWIG object wrappers (e.g. of Computation) are further
// wrapped by xla_client in order to set up a custom destructor that // wrapped by xla_client in order to set up a custom destructor that
// triggers memory deallocation on the C++ side. // triggers memory deallocation on the C++ side.
@ -214,6 +216,15 @@ tensorflow::ImportNumpy();
// Basic types // Basic types
%typemap(out) std::vector<int> {
PyObject* out = PyList_New($1.size());
for (int i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM(out, i, PyInt_FromLong($1[i]));
}
$result = out;
}
%typemap(out) StatusOr<bool> { %typemap(out) StatusOr<bool> {
if ($1.ok()) { if ($1.ok()) {
$result = PyBool_FromLong($1.ConsumeValueOrDie()); $result = PyBool_FromLong($1.ConsumeValueOrDie());
@ -287,12 +298,12 @@ tensorflow::ImportNumpy();
// Computation and buffer/allocation types // Computation and buffer/allocation types
%typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> { %typemap(out) StatusOr<xla::swig::LocalClient> {
if ($1.ok()) { if ($1.ok()) {
auto* value = $1.ValueOrDie(); xla::swig::LocalClient value = $1.ValueOrDie();
{ {
auto* $1 = value; auto $1 = value;
$typemap(out, xla::swig::CompiledLocalComputation*) $typemap(out, xla::swig::LocalClient)
} }
} else { } else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@ -300,12 +311,25 @@ tensorflow::ImportNumpy();
} }
} }
%typemap(out) StatusOr<xla::swig::CompiledXrtComputation*> { %typemap(out) StatusOr<xla::swig::LocalExecutable*> {
if ($1.ok()) { if ($1.ok()) {
auto* value = $1.ValueOrDie(); auto* value = $1.ValueOrDie();
{ {
auto* $1 = value; auto* $1 = value;
$typemap(out, xla::swig::CompiledXrtComputation*) $typemap(out, xla::swig::LocalExecutable*)
}
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
}
%typemap(out) StatusOr<xla::swig::XrtExecutable*> {
if ($1.ok()) {
auto* value = $1.ValueOrDie();
{
auto* $1 = value;
$typemap(out, xla::swig::XrtExecutable*)
} }
} else { } else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@ -365,12 +389,12 @@ tensorflow::ImportNumpy();
} }
} }
%typemap(out) StatusOr<xla::swig::LocalComputation*> { %typemap(out) StatusOr<xla::swig::Computation*> {
if ($1.ok()) { if ($1.ok()) {
auto* value = $1.ValueOrDie(); auto* value = $1.ValueOrDie();
{ {
auto* $1 = value; auto* $1 = value;
$typemap(out, xla::swig::LocalComputation*) $typemap(out, xla::swig::Computation*)
} }
} else { } else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@ -519,18 +543,30 @@ tensorflow::ImportNumpy();
// Shape // Shape
%typemap(out) const Shape& { %typemap(out) const Shape& {
$result = numpy::PyShapeInfoFromXlaShape(*$1); $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
} }
%typemap(out) StatusOr<Shape> { %typemap(out) StatusOr<Shape> {
if ($1.ok()) { if ($1.ok()) {
$result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release();
} else { } else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail; SWIG_fail;
} }
} }
%typemap(out) StatusOr<ProgramShape> {
if ($1.ok()) {
$result = numpy::PyProgramShapeInfoFromXlaProgramShape(
$1.ConsumeValueOrDie()).release();
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
}
%typemap(in) const Shape& (Shape temp) { %typemap(in) const Shape& (Shape temp) {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input); StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) { if (!statusor.ok()) {
@ -558,7 +594,7 @@ tensorflow::ImportNumpy();
} }
%typemap(out) std::unique_ptr<Shape> { %typemap(out) std::unique_ptr<Shape> {
$result = numpy::PyShapeInfoFromXlaShape(*$1); $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
} }
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) { %typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
@ -966,17 +1002,17 @@ tensorflow::ImportNumpy();
%ignoreall %ignoreall
%unignore xla; %unignore xla;
%unignore xla::swig; %unignore xla::swig;
%unignore xla::swig::InitializeReplicaCount;
%unignore xla::swig::InitializePlatformName;
%unignore xla::swig::GetReplicaCount;
%unignore xla::swig::RegisterCpuCustomCallTarget; %unignore xla::swig::RegisterCpuCustomCallTarget;
%unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::LocalClient;
%unignore xla::swig::TransferToInfeedLocalReplica; %unignore xla::swig::LocalClient::Get;
%unignore xla::swig::TransferFromOutfeedLocalReplica; %unignore xla::swig::LocalClient::DeviceCount;
%unignore xla::swig::LocalClient::TransferToInfeed;
%unignore xla::swig::LocalClient::TransferFromOutfeed;
%unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral;
%unignore xla::swig::LocalShapedBuffer::shape; %unignore xla::swig::LocalShapedBuffer::shape;
%unignore xla::swig::LocalShapedBuffer::DestructureTuple;
%unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple;
%unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::Release;
%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::LocalShapedBufferTuple::size;
@ -987,139 +1023,141 @@ tensorflow::ImportNumpy();
%unignore xla::swig::XrtAllocationTuple; %unignore xla::swig::XrtAllocationTuple;
%unignore xla::swig::XrtAllocationTuple::Release; %unignore xla::swig::XrtAllocationTuple::Release;
%unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::XrtAllocationTuple::size;
%unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::LocalExecutable;
%unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::LocalExecutable::DeviceOrdinals;
%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::LocalExecutable::Execute;
%unignore xla::swig::CompiledXrtComputation; %unignore xla::swig::LocalExecutable::ExecutePerReplica;
%unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::XrtExecutable;
%unignore xla::swig::LocalComputation; %unignore xla::swig::XrtExecutable::DeviceOrdinals;
%unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::XrtExecutable::Execute;
%unignore xla::swig::LocalComputation::CompileForXrt; %unignore xla::swig::Computation;
%unignore xla::swig::LocalComputation::GetReturnValueShape; %unignore xla::swig::Computation::Compile;
%unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::Computation::CompileForXrt;
%unignore xla::swig::Computation::GetProgramShape;
%unignore xla::swig::Computation::GetReturnValueShape;
%unignore xla::swig::Computation::GetSerializedProto;
%unignore xla::swig::LocalOp; %unignore xla::swig::LocalOp;
%unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::ComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::ComputationBuilder::ComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::Build; %unignore xla::swig::ComputationBuilder::Build;
%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; %unignore xla::swig::ComputationBuilder::BuildWithRoot;
%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; %unignore xla::swig::ComputationBuilder::SetOpMetadata;
%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; %unignore xla::swig::ComputationBuilder::ClearOpMetadata;
%unignore xla::swig::LocalComputationBuilder::Parameter; %unignore xla::swig::ComputationBuilder::Parameter;
%unignore xla::swig::LocalComputationBuilder::GetShape; %unignore xla::swig::ComputationBuilder::GetShape;
%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; %unignore xla::swig::ComputationBuilder::GetReturnValueShape;
%unignore xla::swig::LocalComputationBuilder::Infeed; %unignore xla::swig::ComputationBuilder::Infeed;
%unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::ComputationBuilder::Outfeed;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::ComputationBuilder::ConstantLiteral;
%unignore xla::swig::LocalComputationBuilder::ConstantR0; %unignore xla::swig::ComputationBuilder::ConstantR0;
%unignore xla::swig::LocalComputationBuilder::Iota; %unignore xla::swig::ComputationBuilder::Iota;
%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; %unignore xla::swig::ComputationBuilder::BroadcastedIota;
%unignore xla::swig::LocalComputationBuilder::Broadcast; %unignore xla::swig::ComputationBuilder::Broadcast;
%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::ComputationBuilder::BroadcastInDim;
%unignore xla::swig::LocalComputationBuilder::Pad; %unignore xla::swig::ComputationBuilder::Pad;
%unignore xla::swig::LocalComputationBuilder::Reshape; %unignore xla::swig::ComputationBuilder::Reshape;
%unignore xla::swig::LocalComputationBuilder::Collapse; %unignore xla::swig::ComputationBuilder::Collapse;
%unignore xla::swig::LocalComputationBuilder::AllToAll; %unignore xla::swig::ComputationBuilder::AllToAll;
%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; %unignore xla::swig::ComputationBuilder::CrossReplicaSum;
%unignore xla::swig::LocalComputationBuilder::Slice; %unignore xla::swig::ComputationBuilder::Slice;
%unignore xla::swig::LocalComputationBuilder::SliceInDim; %unignore xla::swig::ComputationBuilder::SliceInDim;
%unignore xla::swig::LocalComputationBuilder::DynamicSlice; %unignore xla::swig::ComputationBuilder::DynamicSlice;
%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; %unignore xla::swig::ComputationBuilder::DynamicUpdateSlice;
%unignore xla::swig::LocalComputationBuilder::ConcatInDim; %unignore xla::swig::ComputationBuilder::ConcatInDim;
%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; %unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding;
%unignore xla::swig::LocalComputationBuilder::Select; %unignore xla::swig::ComputationBuilder::Select;
%unignore xla::swig::LocalComputationBuilder::Tuple; %unignore xla::swig::ComputationBuilder::Tuple;
%unignore xla::swig::LocalComputationBuilder::GetTupleElement; %unignore xla::swig::ComputationBuilder::GetTupleElement;
%unignore xla::swig::LocalComputationBuilder::ConvertElementType; %unignore xla::swig::ComputationBuilder::ConvertElementType;
%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; %unignore xla::swig::ComputationBuilder::BitcastConvertType;
%unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::ComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::ComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev; %unignore xla::swig::ComputationBuilder::Rev;
%unignore xla::swig::LocalComputationBuilder::Clamp; %unignore xla::swig::ComputationBuilder::Clamp;
%unignore xla::swig::LocalComputationBuilder::Map; %unignore xla::swig::ComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce; %unignore xla::swig::ComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; %unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding;
%unignore xla::swig::LocalComputationBuilder::RngNormal; %unignore xla::swig::ComputationBuilder::RngNormal;
%unignore xla::swig::LocalComputationBuilder::RngUniform; %unignore xla::swig::ComputationBuilder::RngUniform;
%unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::ComputationBuilder::RngBernoulli;
%unignore xla::swig::LocalComputationBuilder::While; %unignore xla::swig::ComputationBuilder::While;
%unignore xla::swig::LocalComputationBuilder::Conditional; %unignore xla::swig::ComputationBuilder::Conditional;
%unignore xla::swig::LocalComputationBuilder::IsConstant; %unignore xla::swig::ComputationBuilder::IsConstant;
%unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::ComputationBuilder::Eq;
%unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::ComputationBuilder::Ne;
%unignore xla::swig::LocalComputationBuilder::Ge; %unignore xla::swig::ComputationBuilder::Ge;
%unignore xla::swig::LocalComputationBuilder::Gt; %unignore xla::swig::ComputationBuilder::Gt;
%unignore xla::swig::LocalComputationBuilder::Lt; %unignore xla::swig::ComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le; %unignore xla::swig::ComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot; %unignore xla::swig::ComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::DotGeneral; %unignore xla::swig::ComputationBuilder::DotGeneral;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; %unignore xla::swig::ComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add; %unignore xla::swig::ComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub; %unignore xla::swig::ComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul; %unignore xla::swig::ComputationBuilder::Mul;
%unignore xla::swig::LocalComputationBuilder::Div; %unignore xla::swig::ComputationBuilder::Div;
%unignore xla::swig::LocalComputationBuilder::Rem; %unignore xla::swig::ComputationBuilder::Rem;
%unignore xla::swig::LocalComputationBuilder::Max; %unignore xla::swig::ComputationBuilder::Max;
%unignore xla::swig::LocalComputationBuilder::Min; %unignore xla::swig::ComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::ComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or; %unignore xla::swig::ComputationBuilder::Or;
%unignore xla::swig::LocalComputationBuilder::Xor; %unignore xla::swig::ComputationBuilder::Xor;
%unignore xla::swig::LocalComputationBuilder::ShiftLeft; %unignore xla::swig::ComputationBuilder::ShiftLeft;
%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; %unignore xla::swig::ComputationBuilder::ShiftRightArithmetic;
%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; %unignore xla::swig::ComputationBuilder::ShiftRightLogical;
%unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::ComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::ComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp; %unignore xla::swig::ComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::ComputationBuilder::Expm1;
%unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::ComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::ComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::ComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log; %unignore xla::swig::ComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::ComputationBuilder::Log1p;
%unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::ComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::ComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin; %unignore xla::swig::ComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh; %unignore xla::swig::ComputationBuilder::Tanh;
%unignore xla::swig::LocalComputationBuilder::Atan2; %unignore xla::swig::ComputationBuilder::Atan2;
%unignore xla::swig::LocalComputationBuilder::IsFinite; %unignore xla::swig::ComputationBuilder::IsFinite;
%unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::ComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::ComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort; %unignore xla::swig::ComputationBuilder::Sort;
%unignore xla::swig::LocalComputationBuilder::SortKeyVal; %unignore xla::swig::ComputationBuilder::SortKeyVal;
%unignore xla::swig::LocalComputationBuilder::Sqrt; %unignore xla::swig::ComputationBuilder::Sqrt;
%unignore xla::swig::LocalComputationBuilder::Rsqrt; %unignore xla::swig::ComputationBuilder::Rsqrt;
%unignore xla::swig::LocalComputationBuilder::Square; %unignore xla::swig::ComputationBuilder::Square;
%unignore xla::swig::LocalComputationBuilder::Reciprocal; %unignore xla::swig::ComputationBuilder::Reciprocal;
%unignore xla::swig::LocalComputationBuilder::Erfc; %unignore xla::swig::ComputationBuilder::Erfc;
%unignore xla::swig::LocalComputationBuilder::Erf; %unignore xla::swig::ComputationBuilder::Erf;
%unignore xla::swig::LocalComputationBuilder::ErfInv; %unignore xla::swig::ComputationBuilder::ErfInv;
%unignore xla::swig::LocalComputationBuilder::Lgamma; %unignore xla::swig::ComputationBuilder::Lgamma;
%unignore xla::swig::LocalComputationBuilder::Digamma; %unignore xla::swig::ComputationBuilder::Digamma;
%unignore xla::swig::LocalComputationBuilder::Acos; %unignore xla::swig::ComputationBuilder::Acos;
%unignore xla::swig::LocalComputationBuilder::Asin; %unignore xla::swig::ComputationBuilder::Asin;
%unignore xla::swig::LocalComputationBuilder::Atan; %unignore xla::swig::ComputationBuilder::Atan;
%unignore xla::swig::LocalComputationBuilder::Tan; %unignore xla::swig::ComputationBuilder::Tan;
%unignore xla::swig::LocalComputationBuilder::Acosh; %unignore xla::swig::ComputationBuilder::Acosh;
%unignore xla::swig::LocalComputationBuilder::Asinh; %unignore xla::swig::ComputationBuilder::Asinh;
%unignore xla::swig::LocalComputationBuilder::Atanh; %unignore xla::swig::ComputationBuilder::Atanh;
%unignore xla::swig::LocalComputationBuilder::Cosh; %unignore xla::swig::ComputationBuilder::Cosh;
%unignore xla::swig::LocalComputationBuilder::Sinh; %unignore xla::swig::ComputationBuilder::Sinh;
%unignore xla::swig::LocalComputationBuilder::Real; %unignore xla::swig::ComputationBuilder::Real;
%unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::ComputationBuilder::Imag;
%unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::ComputationBuilder::Conj;
%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::ComputationBuilder::Complex;
%unignore xla::swig::LocalComputationBuilder::Cholesky; %unignore xla::swig::ComputationBuilder::Cholesky;
%unignore xla::swig::LocalComputationBuilder::QR; %unignore xla::swig::ComputationBuilder::QR;
%unignore xla::swig::LocalComputationBuilder::TriangularSolve; %unignore xla::swig::ComputationBuilder::TriangularSolve;
%unignore xla::swig::LocalComputationBuilder::CustomCall; %unignore xla::swig::ComputationBuilder::CustomCall;
%unignore xla::swig::LocalComputationBuilder::Gather; %unignore xla::swig::ComputationBuilder::Gather;
%unignore xla::swig::LocalComputationBuilder::Scatter; %unignore xla::swig::ComputationBuilder::Scatter;
%unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteComputation;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
%unignore xla::swig::DestructureXrtAllocationTuple; %unignore xla::swig::DestructureXrtAllocationTuple;
%unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalShapedBuffer;
%unignore xla::swig::DeleteXrtAllocation; %unignore xla::swig::DeleteXrtAllocation;
%unignore xla::swig::DeleteCompiledLocalComputation; %unignore xla::swig::DeleteLocalExecutable;
%unignore xla::swig::DeleteCompiledXrtComputation; %unignore xla::swig::DeleteXrtExecutable;
%thread; %thread;
%include "tensorflow/compiler/xla/python/local_computation_builder.h" %include "tensorflow/compiler/xla/python/local_computation_builder.h"

View File

@ -127,28 +127,42 @@ bool NumpyTypeIsValid(int np_type) {
} }
} }
PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) {
int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
PyObject* dimensions; Safe_PyObjectPtr dimensions;
if (shape.IsTuple()) { if (shape.IsTuple()) {
int num_elements = ShapeUtil::TupleElementCount(shape); int num_elements = ShapeUtil::TupleElementCount(shape);
dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape)));
for (int i = 0; i < num_elements; ++i) { for (int i = 0; i < num_elements; ++i) {
PyTuple_SET_ITEM( PyTuple_SET_ITEM(
dimensions, i, dimensions.get(), i,
PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))
.release());
} }
} else { } else {
int rank = shape.rank(); int rank = shape.rank();
dimensions = PyTuple_New(rank); dimensions = make_safe(PyTuple_New(rank));
for (int i = 0; i < rank; ++i) { for (int i = 0; i < rank; ++i) {
PyTuple_SET_ITEM(dimensions, i, PyTuple_SET_ITEM(dimensions.get(), i,
LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
} }
} }
return PyTuple_Pack(2, np_dtype, dimensions); return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release()));
}
Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
const ProgramShape& shape) {
Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size()));
for (int i = 0; i < shape.parameters_size(); ++i) {
PyTuple_SET_ITEM(arg_shapes.get(), i,
PyShapeInfoFromXlaShape(shape.parameters(i)).release());
}
Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result());
return make_safe(
PyTuple_Pack(2, arg_shapes.release(), result_shape.release()));
} }
// Precondition: o->ob_type == &PyArrayDescr_Type // Precondition: o->ob_type == &PyArrayDescr_Type

View File

@ -64,7 +64,13 @@ bool NumpyTypeIsValid(int np_type);
// providing the array dimensions. // providing the array dimensions.
// //
// The return value is a new reference. // The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape); Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape);
// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple
// of argument shapes and result_shape is the result shape. Each shape is as
// described in PyShapeInfoFromXlaShape's comment.
Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
const ProgramShape& shape);
// Converts a Python object with a method interface matching that of // Converts a Python object with a method interface matching that of
// xla_client.Shape into an XLA Shape object. // xla_client.Shape into an XLA Shape object.

View File

@ -36,7 +36,7 @@ from tensorflow.compiler.xla.service import hlo_pb2
# Most functions are snake_case for consistency with other modules, whereas # Most functions are snake_case for consistency with other modules, whereas
# method names of ComputationBuilder and LocalComputation are CamelCase for # method names of ComputationBuilder and Computation are CamelCase for
# consistency with XLA. # consistency with XLA.
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -50,7 +50,7 @@ from tensorflow.compiler.xla.service import hlo_pb2
# which case we need to be able to detect when incompatible versions are # which case we need to be able to detect when incompatible versions are
# installed. # installed.
def version(): def version():
return (0, 1, 7) return (0, 1, 8)
_OP_METADATA_FIELDS = [ _OP_METADATA_FIELDS = [
@ -66,6 +66,10 @@ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
class Backend(object): class Backend(object):
"""Abstract base class for XLA backends.""" """Abstract base class for XLA backends."""
@abc.abstractmethod
def device_count(self):
"""Returns the number of devices known to the backend."""
@abc.abstractmethod @abc.abstractmethod
def buffer_from_pyval(self, pyval, device=0): def buffer_from_pyval(self, pyval, device=0):
"""Allocates a fresh buffer and populates it with `pyval`.""" """Allocates a fresh buffer and populates it with `pyval`."""
@ -95,25 +99,39 @@ class Backend(object):
"""Runs an executable in a replicated manner.""" """Runs an executable in a replicated manner."""
def _maybe_encode_string(s):
if six.PY3:
return s.encode('utf-8')
else:
return s
class XlaLocalBackend(Backend): class XlaLocalBackend(Backend):
"""XLA backend implemented using the in-process xla::LocalClient API.""" """XLA backend implemented using the in-process xla::LocalClient API."""
def __init__(self, platform=None):
platform = platform or _get_default_platform_name()
self.client = c_api.LocalClient.Get(_maybe_encode_string(platform))
def device_count(self):
return self.client.DeviceCount()
def buffer_from_pyval(self, pyval, device=0): def buffer_from_pyval(self, pyval, device=0):
return c_api.LocalShapedBuffer.FromLiteral(pyval, None, device) return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device)
def delete_buffer(self, c_buffer): def delete_buffer(self, c_buffer):
c_api.DeleteLocalShapedBuffer(c_buffer) c_api.DeleteLocalShapedBuffer(c_buffer)
def destructure_tuple(self, c_buffer): def destructure_tuple(self, c_buffer):
result = c_api.DestructureLocalShapedBufferTuple(c_buffer) result = c_buffer.DestructureTuple()
return [result.Release(i) for i in xrange(result.size())] return [result.Release(i) for i in xrange(result.size())]
def compile(self, c_computation, argument_shapes, compile_options): def compile(self, c_computation, argument_shapes, compile_options):
return c_computation.Compile(argument_shapes, compile_options) return c_computation.Compile(argument_shapes, compile_options, self.client)
def delete_executable(self, executable): def delete_executable(self, executable):
assert isinstance(executable, c_api.CompiledLocalComputation) assert isinstance(executable, c_api.LocalExecutable)
c_api.DeleteCompiledLocalComputation(executable) c_api.DeleteLocalExecutable(executable)
def execute(self, executable, args): def execute(self, executable, args):
return executable.Execute(args) return executable.Execute(args)
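A hedged usage sketch of the backend object introduced above, assuming a CPU-only build ('Host' is the default platform name; 'CUDA' would only be valid on a GPU-enabled build). Backends are passed explicitly to LocalBuffer.from_pyval and Computation.Compile below; when omitted, a process-wide default is used.

from tensorflow.compiler.xla.python import xla_client

backend = xla_client.XlaLocalBackend()   # wraps c_api.LocalClient.Get('Host')
print(backend.device_count())            # number of local devices on this platform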
@ -130,6 +148,9 @@ class XrtBackend(Backend):
def __init__(self, target): def __init__(self, target):
self.target = target self.target = target
def device_count(self):
return 1 # Multidevice execution not implemented.
def buffer_from_pyval(self, pyval, device=0): def buffer_from_pyval(self, pyval, device=0):
if device != 0: if device != 0:
raise NotImplementedError( raise NotImplementedError(
@ -150,8 +171,8 @@ class XrtBackend(Backend):
_maybe_encode_string(self.target)) _maybe_encode_string(self.target))
def delete_executable(self, executable): def delete_executable(self, executable):
assert isinstance(executable, c_api.CompiledXrtComputation) assert isinstance(executable, c_api.XrtExecutable)
c_api.DeleteCompiledXrtComputation(executable) c_api.DeleteXrtExecutable(executable)
def execute(self, executable, args): def execute(self, executable, args):
return executable.Execute(args) return executable.Execute(args)
@ -163,7 +184,20 @@ class XrtBackend(Backend):
return [executable.Execute(per_replica_args[0])] return [executable.Execute(per_replica_args[0])]
XLA_LOCAL_BACKEND = XlaLocalBackend() _default_platform_name = 'Host'
_default_backend = None
def _get_default_platform_name():
return _default_platform_name
def _get_default_local_backend():
global _default_backend
global _default_platform_name
if _default_backend is None:
_default_backend = XlaLocalBackend(_default_platform_name)
return _default_backend
class BackendType(enum.Enum): class BackendType(enum.Enum):
@ -174,7 +208,7 @@ class BackendType(enum.Enum):
def BackendSpec(backend, target): def BackendSpec(backend, target):
"""Compatibility wrapper to support older clients. Do not use in new code.""" """Compatibility wrapper to support older clients. Do not use in new code."""
if backend == BackendType.XLA_LOCAL: if backend == BackendType.XLA_LOCAL:
return XLA_LOCAL_BACKEND return _get_default_local_backend()
elif backend == BackendType.XRT: elif backend == BackendType.XRT:
return XrtBackend(target) return XrtBackend(target)
else: else:
@ -201,13 +235,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
source_line=lineno) source_line=lineno)
def _maybe_encode_string(s):
if six.PY3:
return s.encode('utf-8')
else:
return s
class PaddingType(enum.Enum): class PaddingType(enum.Enum):
VALID = 1 VALID = 1
SAME = 2 SAME = 2
@ -346,22 +373,18 @@ class LocalBuffer(object):
means the referent is in device memory. means the referent is in device memory.
""" """
def __init__(self, c_buffer, backend, replica): def __init__(self, c_buffer, backend, device):
self.c_buffer = c_buffer self.c_buffer = c_buffer
self._backend = backend self._backend = backend
self._replica = replica self._device = device
@staticmethod @staticmethod
def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): def from_pyval(pyval, device=0, backend=None):
"""Allocate and copy to XLA the given python value.""" """Allocate and copy to XLA the given python value."""
backend = backend or _get_default_local_backend()
pyval = require_numpy_array_layout(pyval) pyval = require_numpy_array_layout(pyval)
num_replicas = get_replica_count() cbuf = backend.buffer_from_pyval(pyval, device)
if not 0 <= replica < num_replicas: return LocalBuffer(cbuf, backend, device)
raise ValueError(
'Attempt to place buffer on replica {} when the replica count is {}'
.format(replica, num_replicas))
cbuf = backend.buffer_from_pyval(pyval, replica)
return LocalBuffer(cbuf, backend, replica)
def to_py(self): def to_py(self):
return self.c_buffer.ToLiteral() return self.c_buffer.ToLiteral()
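A hedged sketch of the LocalBuffer lifecycle under the new device-based API (default local backend; device 0 assumed to exist):

import numpy as np
from tensorflow.compiler.xla.python import xla_client

arg = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
buf = xla_client.LocalBuffer.from_pyval(arg, device=0)  # copies to device memory
assert buf.device() == 0
host_copy = buf.to_py()                                  # copies back to a numpy value
buf.delete()
assert buf.is_deleted()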
@ -369,8 +392,8 @@ class LocalBuffer(object):
def shape(self): def shape(self):
return _wrap_shape(self.c_buffer.shape()) return _wrap_shape(self.c_buffer.shape())
def replica(self): def device(self):
return self._replica return self._device
def delete(self): def delete(self):
if self.c_buffer is not None: if self.c_buffer is not None:
@ -383,7 +406,7 @@ class LocalBuffer(object):
result = self._backend.destructure_tuple(self.c_buffer) result = self._backend.destructure_tuple(self.c_buffer)
self.delete() self.delete()
return tuple( return tuple(
LocalBuffer(sub_buffer, replica=self._replica, backend=self._backend) LocalBuffer(sub_buffer, device=self._device, backend=self._backend)
for sub_buffer in result) for sub_buffer in result)
def is_deleted(self): def is_deleted(self):
@ -533,6 +556,16 @@ class Shape(object):
updated._check_minor_to_major() # pylint: disable=protected-access updated._check_minor_to_major() # pylint: disable=protected-access
return updated return updated
def with_major_to_minor_layout_if_absent(self):
"""Returns a copy of a shape with missing layouts set to major-to-minor."""
def f(a):
if a.minor_to_major():
return None
return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1)))
return self.map_leaves(f)
def serialize(self, proto): def serialize(self, proto):
"""Serializes 'shape' into proto.""" """Serializes 'shape' into proto."""
if self.is_tuple(): if self.is_tuple():
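A hedged sketch of the new layout helper above; whether Shape.from_pyval already attaches a layout is build-dependent, so the behavior is described in the comment rather than asserted:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

shape = xla_client.Shape.from_pyval(np.zeros((3, 5), np.float32))
laid_out = shape.with_major_to_minor_layout_if_absent()
# For a rank-2 array leaf that lacked a layout, the default filled in is (1, 0);
# leaves that already carry a minor_to_major layout are returned unchanged.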
@ -548,6 +581,10 @@ class Shape(object):
proto.layout.minor_to_major.extend(self.minor_to_major()) proto.layout.minor_to_major.extend(self.minor_to_major())
ProgramShape = collections.namedtuple('ProgramShape',
('parameter_shapes', 'result_shape'))
def _wrap_shape(shape_info): def _wrap_shape(shape_info):
dtype, dims = shape_info dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
@ -581,7 +618,7 @@ class CompileOptions(object):
self.num_replicas = get_replica_count() self.num_replicas = get_replica_count()
def transfer_to_infeed(value, replica_number=None): def transfer_to_infeed(value, device_ordinal=0):
"""Transfers the given value into the XLA infeed queue. """Transfers the given value into the XLA infeed queue.
XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
@ -591,52 +628,49 @@ def transfer_to_infeed(value, replica_number=None):
Args: Args:
value: the value that the caller would like to enqueue into the XLA infeed value: the value that the caller would like to enqueue into the XLA infeed
queue queue
replica_number: the replica number to infeed the value to -- if not device_ordinal: the device to infeed the value to. Each device has a
provided, then the default replica (trivially replica 0) is used. distinct infeed queue.
""" """
if replica_number is None: # TODO(phawkins): support non-default backends.
c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) backend = _get_default_local_backend()
else: backend.client.TransferToInfeed(
c_api.TransferToInfeedLocalReplica( require_numpy_array_layout(value), device_ordinal)
require_numpy_array_layout(value), replica_number)
def transfer_from_outfeed(shape, replica_number=None): def transfer_from_outfeed(shape, device_ordinal=0):
"""Transfers a literal of the given shape from replica_number's outfeed. """Transfers a literal of the given shape from `device_ordinal`'s outfeed.
Args: Args:
shape: The shape of the value to transfer from outfeed. shape: The shape of the value to transfer from outfeed.
replica_number: The replica number ordinal to transfer the outfeed value device_ordinal: The device ordinal to transfer the outfeed value from. Each
from. (Each replica has a distinct outfeed queue.) device has a distinct outfeed queue.
Returns: Returns:
The literal value that is produced from the outfeed queue. The literal value that is produced from the outfeed queue.
""" """
return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) # TODO(phawkins): support non-default backends.
backend = _get_default_local_backend()
return backend.client.TransferFromOutfeed(shape, device_ordinal)
class LocalComputation(object): class Computation(object):
"""Python wrapper for a local XLA Computation. """Python wrapper for an XLA Computation.
A LocalComputation can be executed if it is compiled. Otherwise, it A Computation can be compiled to form an Executable, or used as a
can still be used as a Computation where required by the subcomputation in ComputationBuilder methods.
ComputationBuilder methods.
""" """
def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): def __init__(self, c_computation, backend=None):
self._c_computation = c_computation self._c_computation = c_computation
# The backend argument is deprecated. Pass a backend to Compile() instead.
self._backend = backend self._backend = backend
self._is_compiled = is_compiled
@property @property
def computation(self): def computation(self):
if self._is_compiled:
raise ValueError(
'Attempt to read the XLA computation of a compiled LocalComputation.')
return self._c_computation return self._c_computation
def GetProto(self): def GetProto(self):
"""Get the HloModuleProto proto object in this local computation. """Get the HloModuleProto proto object in this computation.
Returns: Returns:
An HloModuleProto proto object that has the whole-graph information. An HloModuleProto proto object that has the whole-graph information.
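A hedged sketch of the rewritten transfer helpers earlier in this hunk. They now address a device ordinal on the default local backend; in practice they are only meaningful while an executable containing Infeed/Outfeed ops is running on that device:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

xla_client.transfer_to_infeed(np.float32(1.25), device_ordinal=0)
result = xla_client.transfer_from_outfeed(
    xla_client.Shape.from_pyval(np.float32(0.0)), device_ordinal=0)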
@ -645,30 +679,25 @@ class LocalComputation(object):
proto = hlo_pb2.HloModuleProto.FromString(serialized) proto = hlo_pb2.HloModuleProto.FromString(serialized)
return proto return proto
def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None,
"""Compiles an un-compiled local computation. backend=None):
"""Compiles a computation.
Local computations are the result of a "LocalComputationBuild'ing" process Computations are the result of a "ComputationBuild'ing" process.
-- they start in uncompiled form, and via a call to Compile() turn into a
compiled local computation.
Raises:
ValueError: if this is already a compiled local computation.
Arguments: Arguments:
argument_shapes: parameter shapes -- they are first laid out by layout_fn argument_shapes: parameter shapes -- they are first laid out by layout_fn
if layout_fn is provided. Otherwise, the default layout for those shapes if layout_fn is provided. Otherwise, the default layout for those shapes
will be used. will be used.
compile_options: options to use for compilation, includes an optional compile_options: options to use for compilation, including an optional laid
laid out result shape for the computation. out result shape for the computation.
layout_fn: lambda that is used to lay out the argument/result shapes. layout_fn: lambda that is used to lay out the argument/result shapes.
backend: a `Backend` for which an executable should be generated.
Returns: Returns:
A newly *compiled* local computation instance. An Executable instance.
""" """
if self._is_compiled: backend = backend or self._backend or _get_default_local_backend()
raise ValueError('Attempt to compile a compiled local XLA computation.')
result_shape = _wrap_shape(self.computation.GetReturnValueShape()) result_shape = _wrap_shape(self.computation.GetReturnValueShape())
if layout_fn: if layout_fn:
@ -681,29 +710,55 @@ class LocalComputation(object):
compile_options = compile_options or CompileOptions() compile_options = compile_options or CompileOptions()
compile_options.result_shape = result_shape compile_options.result_shape = result_shape
c = self._backend.compile(self.computation, argument_shapes, c = backend.compile(self.computation, argument_shapes, compile_options)
compile_options) return Executable(c, backend=backend)
return LocalComputation(c, is_compiled=True, backend=self._backend)
def CompileWithExampleArguments(self, def CompileWithExampleArguments(self,
arguments=(), arguments=(),
compile_options=None, compile_options=None,
layout_fn=None): layout_fn=None,
backend=None):
return self.Compile( return self.Compile(
argument_shapes=[Shape.from_pyval(arg) for arg in arguments], argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
compile_options=compile_options, compile_options=compile_options,
layout_fn=layout_fn) layout_fn=layout_fn,
backend=backend)
def GetProgramShape(self):
(arg_shapes, result_shape) = self._c_computation.GetProgramShape()
return ProgramShape([_wrap_shape(arg) for arg in arg_shapes],
_wrap_shape(result_shape))
def GetReturnValueShape(self): def GetReturnValueShape(self):
return _wrap_shape(self._c_computation.GetReturnValueShape()) return _wrap_shape(self._c_computation.GetReturnValueShape())
def __del__(self):
# Python may have freed c_api first.
if c_api and self._c_computation:
assert isinstance(self._c_computation, c_api.Computation)
c_api.DeleteComputation(self._c_computation)
class Executable(object):
"""Python wrapper for an XLA Executable."""
def __init__(self, c_executable, backend=None):
self._c_executable = c_executable
self._device_ordinals = c_executable.DeviceOrdinals()
self._backend = backend
def DeviceOrdinals(self):
"""Returns a list containing the device ordinals for each replica."""
return self._device_ordinals
def Execute(self, arguments=(), check_for_deleted_args=True): def Execute(self, arguments=(), check_for_deleted_args=True):
"""Execute on one replica with LocalBuffer arguments and return value.""" """Execute on one replica with LocalBuffer arguments and return value."""
if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
raise ValueError('Executing with deleted local buffer argument') raise ValueError('Executing with deleted local buffer argument')
raw_args = [arg.c_buffer for arg in arguments] raw_args = [arg.c_buffer for arg in arguments]
output_buffer = self._backend.execute(self._c_computation, raw_args) output_buffer = self._backend.execute(self._c_executable, raw_args)
return LocalBuffer(output_buffer, backend=self._backend, replica=0) return LocalBuffer(
output_buffer, backend=self._backend, device=self._device_ordinals[0])
def ExecutePerReplica(self, arguments=None): def ExecutePerReplica(self, arguments=None):
"""Execute on many replicas with LocalBuffer arguments and return value. """Execute on many replicas with LocalBuffer arguments and return value.
@ -713,14 +768,12 @@ class LocalComputation(object):
sequence comprises the arguments for execution on the i'th replica. sequence comprises the arguments for execution on the i'th replica.
Returns: Returns:
A list of the computation's outputs on each replica, as a LocalBuffer. If A list of the computation's outputs for each replica, as a LocalBuffer. If
a shallow sequence of arguments was passed in for `arguments`, then the a shallow sequence of arguments was passed in for `arguments`, then the
sole, zero'th replica's output is returned instead, as a LocalBuffer. sole, zero'th replica's output is returned instead, as a LocalBuffer.
""" """
if not self._is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
if arguments is None: if arguments is None:
arguments = ((),) * get_replica_count() arguments = ((),) * len(self._device_ordinals)
else: else:
arguments = [list(replica_args) for replica_args in arguments] arguments = [list(replica_args) for replica_args in arguments]
@ -729,30 +782,35 @@ class LocalComputation(object):
for arg in replica_args: for arg in replica_args:
if arg.is_deleted(): if arg.is_deleted():
raise ValueError('Executing with deleted local buffer argument') raise ValueError('Executing with deleted local buffer argument')
if arg.replica() != replica: if arg.device() != self._device_ordinals[replica]:
raise ValueError( raise ValueError(
'Executing on replica {} with argument from replica {}'.format( 'Executing on device {} with argument from device {}'.format(
replica, arg.replica())) self._device_ordinals[replica], arg.device()))
# Pull out argument buffer handles # Pull out argument buffer handles
# pylint: disable=g-complex-comprehension
stripped_args = [ stripped_args = [
[arg.c_buffer for arg in replica_args] for replica_args in arguments [arg.c_buffer for arg in replica_args] for replica_args in arguments
] ]
# Execute # Execute
output_buffers = self._backend.execute_replicated( output_buffers = self._backend.execute_replicated(self._c_executable,
self._c_computation, stripped_args) stripped_args)
# Wrap output handles in LocalBuffer instances # Wrap output handles in LocalBuffer instances
return tuple( return tuple(
LocalBuffer(output_buffer, backend=self._backend, replica=replica) LocalBuffer(
output_buffer,
backend=self._backend,
device=self._device_ordinals[replica])
for replica, output_buffer in enumerate(output_buffers)) for replica, output_buffer in enumerate(output_buffers))
def ExecuteWithPythonValues(self, arguments=()): def ExecuteWithPythonValues(self, arguments=()):
"""Execute on one replica with Python values as arguments and output.""" """Execute on one replica with Python values as arguments and output."""
def put(arg): def put(arg):
return LocalBuffer.from_pyval(arg, backend=self._backend) return LocalBuffer.from_pyval(
arg, device=self._device_ordinals[0], backend=self._backend)
arguments = [put(arg) for arg in arguments] arguments = [put(arg) for arg in arguments]
return self.Execute(arguments).to_py() return self.Execute(arguments).to_py()
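A hedged end-to-end sketch of the renamed pipeline (ComputationBuilder, then Computation, then Executable). The Constant and Add builder methods are assumed from the forwarded C-extension ops and surrounding test usage; the exact op-construction helpers may differ:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

builder = xla_client.ComputationBuilder('double')
x = builder.Constant(np.float32(3.0))        # assumed constant-construction helper
builder.Add(x, x)                            # forwarded binary op; last op is the root
computation = builder.Build()                # a Computation (formerly LocalComputation)
print(computation.GetProgramShape())         # ProgramShape(parameter_shapes, result_shape)
executable = computation.Compile()           # an Executable on the default local backend
print(executable.DeviceOrdinals())           # e.g. [0] for a single-replica compile
print(executable.ExecuteWithPythonValues())  # approximately 6.0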
@ -760,22 +818,19 @@ class LocalComputation(object):
def ExecuteWithPythonValuesPerReplica(self, arguments): def ExecuteWithPythonValuesPerReplica(self, arguments):
"""Execute on many replicas with Python values as arguments and output.""" """Execute on many replicas with Python values as arguments and output."""
def put(arg, replica): def put(arg, device):
return LocalBuffer.from_pyval(arg, replica, backend=self._backend) return LocalBuffer.from_pyval(arg, device, backend=self._backend)
arguments = [[put(arg, replica) # pylint: disable=g-complex-comprehension
for arg in replica_args] arguments = [[
for replica, replica_args in enumerate(arguments)] put(arg, self._device_ordinals[replica]) for arg in replica_args
] for replica, replica_args in enumerate(arguments)]
return [out.to_py() for out in self.ExecutePerReplica(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)]
def __del__(self): def __del__(self):
# Python may have freed c_api first. # Python may have freed c_api first.
if c_api and self._c_computation: if c_api and self._c_executable:
if self._is_compiled: self._backend.delete_executable(self._c_executable)
self._backend.delete_executable(self._c_computation)
else:
assert isinstance(self._c_computation, c_api.LocalComputation)
c_api.DeleteLocalComputation(self._c_computation)
def _make_replica_group_proto(replica_group): def _make_replica_group_proto(replica_group):
@ -788,8 +843,8 @@ class ComputationBuilder(object):
"""XLA computation builder. """XLA computation builder.
Enqueues XLA ops in sequence and in order to build a Enqueues XLA ops in sequence and in order to build a
LocalComputation, which in turn can be compiled into a Computation, which in turn can be compiled into a
CompiledLocalComputation, which in turn can be locally executed. LocalExecutable, which in turn can be locally executed.
""" """
# The methods of this class map 1-to-1 onto the XLA C++ # The methods of this class map 1-to-1 onto the XLA C++
@ -800,16 +855,23 @@ class ComputationBuilder(object):
# pylint: disable=g-doc-args # pylint: disable=g-doc-args
def __init__(self, name): def __init__(self, name):
self._client = c_api.LocalComputationBuilder(name.encode('utf8')) self._client = c_api.ComputationBuilder(name.encode('utf8'))
self._parameter_numbering = itertools.count() self._parameter_numbering = itertools.count()
def Build(self, root=None, backend=XLA_LOCAL_BACKEND): def Build(self, root=None, backend=None):
"""Builds a `Computation` from the contents of the builder.
Args:
root: if not None, the operator containing the return value of the
computation.
backend: deprecated. Pass a `backend` to `Computation.Compile` instead.
Returns:
A `Computation`.
"""
if root is not None: if root is not None:
return LocalComputation( return Computation(self._client.BuildWithRoot(root), backend=backend)
self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
else: else:
return LocalComputation( return Computation(self._client.Build(), backend=backend)
self._client.Build(), is_compiled=False, backend=backend)
def SetOpMetadata(self, op_metadata): def SetOpMetadata(self, op_metadata):
"""Set metadata for operations that are about to be enqueued.""" """Set metadata for operations that are about to be enqueued."""
@ -1461,7 +1523,7 @@ class ComputationBuilder(object):
Args: Args:
operand: a LocalOp to test. operand: a LocalOp to test.
Returns: a LocalComputation that is rooted on the given `operand` which is a Returns: a Computation that is rooted on the given `operand` which is a
compile-time constant. compile-time constant.
""" """
return self._client.BuildConstantSubGraph(operand) return self._client.BuildConstantSubGraph(operand)
@ -1662,7 +1724,7 @@ def _forward_methods_to_local_builder():
Set up methods, corresponding to unary and binary XLA operations, Set up methods, corresponding to unary and binary XLA operations,
whose calls are forwarded in a boilerplate manner to the underlying whose calls are forwarded in a boilerplate manner to the underlying
LocalComputationBuilder C-extension API. ComputationBuilder C-extension API.
""" """
def forward_to_local_builder_with_handles(target_method, is_binop=False): def forward_to_local_builder_with_handles(target_method, is_binop=False):
@ -1682,13 +1744,13 @@ def _forward_methods_to_local_builder():
for method_name in _UNARY_OPS: for method_name in _UNARY_OPS:
forward = forward_to_local_builder_with_handles( forward = forward_to_local_builder_with_handles(
getattr(c_api.LocalComputationBuilder, method_name)) getattr(c_api.ComputationBuilder, method_name))
forward.__name__ = method_name forward.__name__ = method_name
setattr(ComputationBuilder, method_name, forward) setattr(ComputationBuilder, method_name, forward)
for method_name in _BINARY_OPS: for method_name in _BINARY_OPS:
forward = forward_to_local_builder_with_handles( forward = forward_to_local_builder_with_handles(
getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) getattr(c_api.ComputationBuilder, method_name), is_binop=True)
forward.__name__ = method_name forward.__name__ = method_name
setattr(ComputationBuilder, method_name, forward) setattr(ComputationBuilder, method_name, forward)
@ -1696,8 +1758,14 @@ def _forward_methods_to_local_builder():
_forward_methods_to_local_builder() _forward_methods_to_local_builder()
_default_replica_count = 1
def initialize_replica_count(replica_count): def initialize_replica_count(replica_count):
"""Initializes the desired replica count to use on XLA service init. """Initializes the default replica count to use.
Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
instead.
Args: Args:
replica_count: number of replicas that are desired for set up during XLA replica_count: number of replicas that are desired for set up during XLA
@ -1706,31 +1774,30 @@ def initialize_replica_count(replica_count):
Raises: Raises:
A runtime exception if the XLA service has already been initialized. A runtime exception if the XLA service has already been initialized.
""" """
c_api.InitializeReplicaCount(replica_count) global _default_replica_count
_default_replica_count = replica_count
def initialize_platform_name(platform_name):
"""Initializes the desired platform name to use on XLA service init.
Args:
platform_name: string name of platform.
Raises:
A runtime exception if the XLA service has already been initialized.
A runtime exception if the platform does not exist, or there are no devices
with that platform.
"""
platform_name = _maybe_encode_string(platform_name)
c_api.InitializePlatformName(platform_name)
def get_replica_count(): def get_replica_count():
"""Returns the current replica count used for the XLA service. """Returns the default replica count.
Note: this will return a value whether the XLA service has been initialized Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
yet or not. instead.
""" """
return c_api.GetReplicaCount() return _default_replica_count
def initialize_platform_name(platform_name):
"""Initializes the default platform name to use for XLA.
Args:
platform_name: string name of platform.
"""
global _default_platform_name
_default_platform_name = platform_name
# Make sure the platform is valid by trying to instantiate it.
_get_default_local_backend()
def register_cpu_custom_call_target(name, fn): def register_cpu_custom_call_target(name, fn):
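A hedged sketch of the deprecated module-level knobs in their new, pure-Python form ('Host' is the default platform; initialize_platform_name validates its argument by instantiating a backend):

from tensorflow.compiler.xla.python import xla_client

xla_client.initialize_platform_name('Host')
assert xla_client.get_replica_count() == 1  # default unless initialize_replica_count() was used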

View File

@ -29,7 +29,7 @@ from tensorflow.compiler.xla.python import xla_client
import unittest import unittest
class LocalComputationTest(unittest.TestCase): class ComputationTest(unittest.TestCase):
"""Base class for running an XLA Computation through the local client.""" """Base class for running an XLA Computation through the local client."""
def _NewComputation(self, name=None): def _NewComputation(self, name=None):
@ -85,7 +85,7 @@ def NumpyArrayBool(*args, **kwargs):
return np.array(*args, dtype=np.bool, **kwargs) return np.array(*args, dtype=np.bool, **kwargs)
class ComputationsWithConstantsTest(LocalComputationTest): class ComputationsWithConstantsTest(ComputationTest):
"""Tests focusing on Constant ops.""" """Tests focusing on Constant ops."""
def testConstantScalarSumS8(self): def testConstantScalarSumS8(self):
@ -304,7 +304,7 @@ class ComputationsWithConstantsTest(LocalComputationTest):
self._ExecuteAndCompareClose(c, expected=0.75) self._ExecuteAndCompareClose(c, expected=0.75)
class ParametersTest(LocalComputationTest): class ParametersTest(ComputationTest):
"""Tests focusing on Parameter ops and argument-passing.""" """Tests focusing on Parameter ops and argument-passing."""
def setUp(self): def setUp(self):
@ -384,7 +384,7 @@ class ParametersTest(LocalComputationTest):
expected=[-4.3, 1.3, -6.3, 3.3]) expected=[-4.3, 1.3, -6.3, 3.3])
class LocalBufferTest(LocalComputationTest): class LocalBufferTest(ComputationTest):
"""Tests focusing on execution with LocalBuffers.""" """Tests focusing on execution with LocalBuffers."""
def _Execute(self, c, arguments): def _Execute(self, c, arguments):
@ -482,7 +482,7 @@ class LocalBufferTest(LocalComputationTest):
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
class SingleOpTest(LocalComputationTest): class SingleOpTest(ComputationTest):
"""Tests for single ops. """Tests for single ops.
The goal here is smoke testing - to exercise the most basic functionality of The goal here is smoke testing - to exercise the most basic functionality of
@ -1175,7 +1175,7 @@ class SingleOpTest(LocalComputationTest):
np.testing.assert_allclose(g, expected, rtol=1e-4) np.testing.assert_allclose(g, expected, rtol=1e-4)
class EmbeddedComputationsTest(LocalComputationTest): class EmbeddedComputationsTest(ComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps).""" """Tests for XLA graphs with embedded computations (such as maps)."""
def _CreateConstantS32Computation(self): def _CreateConstantS32Computation(self):
@ -1639,7 +1639,7 @@ class EmbeddedComputationsTest(LocalComputationTest):
self._ExecuteAndCompareClose(c, expected=expected) self._ExecuteAndCompareClose(c, expected=expected)
class ErrorTest(LocalComputationTest): class ErrorTest(ComputationTest):
def setUp(self): def setUp(self):
self.f32_scalar_2 = NumpyArrayF32(2.0) self.f32_scalar_2 = NumpyArrayF32(2.0)
@ -1656,7 +1656,7 @@ class ErrorTest(LocalComputationTest):
lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
class ComputationRootTest(LocalComputationTest): class ComputationRootTest(ComputationTest):
"""Tests related to setting the root of the computation.""" """Tests related to setting the root of the computation."""
def testComputationRootDifferentFromLastOp(self): def testComputationRootDifferentFromLastOp(self):

View File

@ -3529,6 +3529,37 @@ tf_cc_test(
], ],
) )
cc_library(
name = "stable_sort_expander",
srcs = ["stable_sort_expander.cc"],
hdrs = ["stable_sort_expander.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_pass",
":op_expander_pass",
"//tensorflow/compiler/xla:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],
)
tf_cc_test(
name = "stable_sort_expander_test",
srcs = ["stable_sort_expander_test.cc"],
deps = [
":algebraic_simplifier",
":hlo_matchers",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
":stable_sort_expander",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
],
)
cc_library( cc_library(
name = "tuple_util", name = "tuple_util",
srcs = ["tuple_util.cc"], srcs = ["tuple_util.cc"],

View File

@ -280,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
hlo)); hlo));
} }
// Helper method to perform and add reduction in a single dimension. // Converts to primitive type if the input hlo is not that type, otherwise
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { // returns the original hlo.
HloInstruction* AsType(HloInstruction* hlo,
const PrimitiveType element_type) {
if (hlo->shape().element_type() == element_type) {
return hlo;
}
return computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
}
// Transposes a dot operand such that the batch dimensions are the most major,
// and the contracting dimensions are most minor.
StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
absl::Span<const int64> contracting_dimensions) {
std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
batch_dimensions.end());
for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
if (!(absl::c_linear_search(batch_dimensions, i) ||
absl::c_linear_search(contracting_dimensions, i))) {
transpose_dimensions.push_back(i);
}
}
transpose_dimensions.insert(transpose_dimensions.end(),
contracting_dimensions.begin(),
contracting_dimensions.end());
return MakeTransposeHlo(dot_operand, transpose_dimensions);
}
// Helper method to create an add-reduction over a list of dimensions and add
// it to the computation.
HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
HloInstruction* zero = HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant( computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type()).Clone())); LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); Shape shape = ShapeUtil::FilterDimensions(
[&](int64 dim) { return !absl::c_linear_search(dims, dim); },
hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce( return computation_->AddInstruction(HloInstruction::CreateReduce(
shape, hlo, zero, {dim}, AddReduce_computation)); shape, hlo, zero, dims, AddReduce_computation));
}
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
return AddReduce(hlo, std::vector<int64>{dim});
} }
// Convenience method for replacing an instruction with a bitcast. If operand // Convenience method for replacing an instruction with a bitcast. If operand
@ -1120,16 +1156,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
std::swap(rhs_collapsing_dim, rhs_kept_dim); std::swap(rhs_collapsing_dim, rhs_kept_dim);
} }
auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) {
if (hlo->shape().element_type() == element_type) {
return hlo;
}
return computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
};
auto reshape_if_necessary = [&](HloInstruction* hlo) { auto reshape_if_necessary = [&](HloInstruction* hlo) {
hlo = as_type(hlo, dot->shape().element_type()); hlo = AsType(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction( hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo)); HloInstruction::CreateReshape(dot->shape(), hlo));
@ -1138,7 +1166,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
}; };
auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
return AddReduce(as_type(hlo, F32), dim); return AddReduce(AsType(hlo, F32), dim);
}; };
auto broadcast = [&](HloInstruction* hlo, const Shape& shape, auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
@ -1247,8 +1275,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
return dims; return dims;
}; };
// If the contracting dimension is 1, remove the degnerate dimnesions from the // If the contracting dimension is 1, remove the degenerate dimensions from
// lhs and rhs, broadcast each to the result shape and multiply. // the lhs and rhs, broadcast each to the result shape and multiply.
if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
(rhs_kept_dim == rhs_rank - 1 || (rhs_kept_dim == rhs_rank - 1 ||
(rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
@ -1608,34 +1636,26 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
// If there are no contracting dimensions, a dot can be rewritten as // If there are no contracting dimensions, a dot can be rewritten as
// mul(broadcast(transpose(x)),broadcast(transpose(y))) // mul(broadcast(transpose(x)),broadcast(transpose(y)))
if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
std::vector<int64> lhs_transpose( TF_ASSIGN_OR_RETURN(
dot->dot_dimension_numbers().lhs_batch_dimensions().begin(), HloInstruction * new_lhs,
dot->dot_dimension_numbers().lhs_batch_dimensions().end()); NormalizeDotOperandToBatchMajorAndContractingMinor(
for (int64 i = 0; i < lhs->shape().rank(); ++i) { lhs,
if (!absl::c_linear_search( AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) { AsInt64Slice(
lhs_transpose.push_back(i); dot->dot_dimension_numbers().lhs_contracting_dimensions())));
}
}
TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
MakeTransposeHlo(lhs, lhs_transpose));
if (dot->shape().rank() != lhs->shape().rank()) { if (dot->shape().rank() != lhs->shape().rank()) {
std::vector<int64> lhs_broadcast_dims(lhs->shape().rank()); std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
absl::c_iota(lhs_broadcast_dims, 0); absl::c_iota(lhs_broadcast_dims, 0);
new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
dot->shape(), new_lhs, lhs_broadcast_dims)); dot->shape(), new_lhs, lhs_broadcast_dims));
} }
std::vector<int64> rhs_transpose( TF_ASSIGN_OR_RETURN(
dot->dot_dimension_numbers().rhs_batch_dimensions().begin(), HloInstruction * new_rhs,
dot->dot_dimension_numbers().rhs_batch_dimensions().end()); NormalizeDotOperandToBatchMajorAndContractingMinor(
for (int64 i = 0; i < rhs->shape().rank(); ++i) { rhs,
if (!absl::c_linear_search( AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) { AsInt64Slice(
rhs_transpose.push_back(i); dot->dot_dimension_numbers().rhs_contracting_dimensions())));
}
}
TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
MakeTransposeHlo(rhs, rhs_transpose));
if (dot->shape().rank() != rhs->shape().rank()) { if (dot->shape().rank() != rhs->shape().rank()) {
std::vector<int64> rhs_broadcast_dims( std::vector<int64> rhs_broadcast_dims(
dot->dot_dimension_numbers().lhs_batch_dimensions_size()); dot->dot_dimension_numbers().lhs_batch_dimensions_size());
@ -1651,6 +1671,78 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
new_lhs, new_rhs)); new_lhs, new_rhs));
} }
// If the lhs or rhs have only batch and contracting dimensions, a dot can be
// rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
lhs->shape().rank()) ||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
rhs->shape().rank())) {
TF_ASSIGN_OR_RETURN(
HloInstruction * new_lhs,
NormalizeDotOperandToBatchMajorAndContractingMinor(
lhs,
AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
AsInt64Slice(
dot->dot_dimension_numbers().lhs_contracting_dimensions())));
TF_ASSIGN_OR_RETURN(
HloInstruction * new_rhs,
NormalizeDotOperandToBatchMajorAndContractingMinor(
rhs,
AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
AsInt64Slice(
dot->dot_dimension_numbers().rhs_contracting_dimensions())));
int64 lhs_outer_dims =
lhs->shape().rank() -
(dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
int64 rhs_outer_dims =
rhs->shape().rank() -
(dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
if (rhs_outer_dims > 0) {
std::vector<int64> lhs_broadcast_dims(
dot->dot_dimension_numbers().lhs_batch_dimensions_size());
absl::c_iota(lhs_broadcast_dims, 0);
lhs_broadcast_dims.resize(lhs->shape().rank());
std::iota(lhs_broadcast_dims.begin() +
dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
lhs_broadcast_dims.end(),
dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
rhs_outer_dims);
new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
new_rhs->shape(), new_lhs, lhs_broadcast_dims));
} else if (lhs_outer_dims > 0) {
std::vector<int64> rhs_broadcast_dims(
dot->dot_dimension_numbers().rhs_batch_dimensions_size());
absl::c_iota(rhs_broadcast_dims, 0);
rhs_broadcast_dims.resize(rhs->shape().rank());
std::iota(rhs_broadcast_dims.begin() +
dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
rhs_broadcast_dims.end(),
dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
lhs_outer_dims);
new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
new_lhs->shape(), new_rhs, rhs_broadcast_dims));
}
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
std::vector<int64> reduce_dims(
dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
new_dot = AsType(new_dot, F32);
const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
absl::c_iota(
reduce_dims,
outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
new_dot = AddReduce(new_dot, reduce_dims);
new_dot = AsType(new_dot, dot->shape().element_type());
return ReplaceInstruction(dot, new_dot);
}
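To make the rewrite above concrete, the following self-contained check uses plain loops and hypothetical shapes (lhs [B=2, M=3, K=4], rhs [B=2, K=4], so the rhs has only batch and contracting dimensions) to verify that the dot equals an elementwise multiply against the broadcast rhs followed by an add-reduction over the contracting dimension:

#include <cassert>
#include <cmath>
#include <iostream>

int main() {
  // Hypothetical shapes: lhs is [B=2, M=3, K=4]; rhs is [B=2, K=4], i.e. it
  // has only a batch and a contracting dimension.
  double lhs[2][3][4], rhs[2][4];
  for (int b = 0; b < 2; ++b)
    for (int m = 0; m < 3; ++m)
      for (int k = 0; k < 4; ++k) lhs[b][m][k] = b + 0.1 * m + 0.01 * k;
  for (int b = 0; b < 2; ++b)
    for (int k = 0; k < 4; ++k) rhs[b][k] = 1.0 + 0.5 * b - 0.2 * k;

  // Broadcast rhs to the lhs shape (M is the broadcast dimension).
  double bcast[2][3][4];
  for (int b = 0; b < 2; ++b)
    for (int m = 0; m < 3; ++m)
      for (int k = 0; k < 4; ++k) bcast[b][m][k] = rhs[b][k];

  // Elementwise multiply, add-reduce the contracting dimension K, and
  // compare against the direct dot contraction.
  for (int b = 0; b < 2; ++b)
    for (int m = 0; m < 3; ++m) {
      double reduced = 0, direct = 0;
      for (int k = 0; k < 4; ++k) {
        reduced += lhs[b][m][k] * bcast[b][m][k];
        direct += lhs[b][m][k] * rhs[b][k];
      }
      assert(std::abs(reduced - direct) < 1e-12);
    }
  std::cout << "reduce(mul(lhs, broadcast(rhs))) matches the dot\n";
  return 0;
}
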
if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
dot->shape().rank() > 2) { dot->shape().rank() > 2) {
if (options_.enable_dot_strength_reduction() && if (options_.enable_dot_strength_reduction() &&

View File

@ -2753,8 +2753,9 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
auto keys = builder.AddInstruction( auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys")); HloInstruction::CreateParameter(0, keys_shape, "keys"));
TF_ASSERT_OK( TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder,
MakeSortHlo(keys_shape, {keys}, 0, &builder, module.get()).status()); module.get())
.status());
HloComputation* computation = module->AddEntryComputation(builder.Build()); HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_); AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
@ -2775,7 +2776,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
HloInstruction::CreateParameter(2, values_shape, "values1")); HloInstruction::CreateParameter(2, values_shape, "values1"));
TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape(
{keys_shape, values_shape, values_shape}), {keys_shape, values_shape, values_shape}),
{keys, values0, values1}, 0, &builder, module.get()) {keys, values0, values1}, 0, /*is_stable=*/false,
&builder, module.get())
.status()); .status());
HloComputation* computation = module->AddEntryComputation(builder.Build()); HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_); AlgebraicSimplifier simplifier(default_options_);
@ -3712,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
HloInstruction* y = HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
DotDimensionNumbers dot_dnums; DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_lhs_batch_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(0);
builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
DefaultPrecisionConfig(2))); DefaultPrecisionConfig(2)));
std::unique_ptr<HloComputation> dot_computation(builder.Build()); std::unique_ptr<HloComputation> dot_computation(builder.Build());
@ -4220,12 +4222,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
int m, k, n; int m, k, n;
PrimitiveType element_type; PrimitiveType element_type;
std::tie(m, k, n, element_type) = GetParam(); std::tie(m, k, n, element_type) = GetParam();
std::vector<int64> lhs_dims = {1, 3, 5};
Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); std::vector<int64> rhs_dims = lhs_dims;
Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}) std::vector<int64> output_dims = lhs_dims;
: ShapeUtil::MakeShape(element_type, {1, 3, 5, m}); if (m > 0) {
Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}) lhs_dims.push_back(m);
: ShapeUtil::MakeShape(element_type, {1, 3, 5, n}); output_dims.push_back(m);
}
if (k > 0) {
lhs_dims.push_back(k);
rhs_dims.push_back(k);
}
if (n > 0) {
rhs_dims.push_back(n);
output_dims.push_back(n);
}
Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
HloComputation::Builder builder(TestName()); HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction( auto lhs = builder.AddInstruction(
@ -4240,7 +4254,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
dot_dnums.add_rhs_batch_dimensions(1); dot_dnums.add_rhs_batch_dimensions(1);
dot_dnums.add_rhs_batch_dimensions(2); dot_dnums.add_rhs_batch_dimensions(2);
if (k > 0) { if (k > 0) {
dot_dnums.add_lhs_contracting_dimensions(4); dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
dot_dnums.add_rhs_contracting_dimensions(3); dot_dnums.add_rhs_contracting_dimensions(3);
} }
builder.AddInstruction(HloInstruction::CreateDot( builder.AddInstruction(HloInstruction::CreateDot(
@ -4248,9 +4262,9 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
auto computation = module->AddEntryComputation(builder.Build()); auto computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_); AlgebraicSimplifier simplifier(default_options_);
TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1; const bool dot_should_be_transformed =
const bool computation_should_be_modified = dot_should_be_transformed; m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
EXPECT_EQ(changed, computation_should_be_modified); EXPECT_EQ(changed, dot_should_be_transformed);
bool has_no_dot = true; bool has_no_dot = true;
for (const auto& hlo : computation->instructions()) { for (const auto& hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kDot) { if (hlo->opcode() == HloOpcode::kDot) {
@ -4261,10 +4275,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
EXPECT_EQ(has_no_dot, dot_should_be_transformed); EXPECT_EQ(has_no_dot, dot_should_be_transformed);
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation,
BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, BatchDotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2), ::testing::Combine(::testing::Values(-1, 1, 2),
::testing::Values(1, 2), ::testing::Values(F32, BF16))); ::testing::Values(-1, 1, 2),
::testing::Values(-1, 1, 2),
::testing::Values(F32, BF16)));
class DotStrengthReductionTest class DotStrengthReductionTest
: public AlgebraicSimplifierTest, : public AlgebraicSimplifierTest,

View File

@ -32,15 +32,13 @@ limitations under the License.
namespace xla { namespace xla {
namespace {
namespace m = match; namespace m = match;
// Checks if the argument instruction is an AllReduce, followed by a certain // Checks if the argument instruction is an AllReduce, followed by a certain
// sequence of instructions and then a CRS. It must be possible to move // sequence of instructions and then a CRS. It must be possible to move
// the AR past each instruction in the sequence. Returns the CRS, which is the // the AR past each instruction in the sequence. Returns the CRS, which is the
// last instruction in the sequence. // last instruction in the sequence.
absl::optional<HloInstruction*> MatchesArCrsPattern( absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
HloInstruction* instruction) { HloInstruction* instruction) {
auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
if (instruction->user_count() != 1) { if (instruction->user_count() != 1) {
@ -77,23 +75,23 @@ absl::optional<HloInstruction*> MatchesArCrsPattern(
return absl::nullopt; return absl::nullopt;
} }
auto next = instruction->users()[0]; auto next = instruction->users()[0];
int64 distance = 1;
while (!next->IsCrossReplicaAllReduce()) { while (!next->IsCrossReplicaAllReduce()) {
if (can_ar_move_past_instruction(next)) { if (can_ar_move_past_instruction(next)) {
next = next->users()[0]; next = next->users()[0];
} else { } else {
return absl::nullopt; return absl::nullopt;
} }
++distance;
} }
if (!Cast<HloAllReduceInstruction>(next)->IsNoop() && if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
computation_is_addition(next->called_computations()[0])) { computation_is_addition(next->called_computations()[0])) {
return absl::optional<HloInstruction*>(next); return absl::optional<ArCrsPair>(ArCrsPair(instruction, next, distance));
} else { } else {
return absl::nullopt; return absl::nullopt;
} }
} }
} // namespace
absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter( absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
HloInstruction* instruction) { HloInstruction* instruction) {
CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
@ -235,15 +233,55 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
} }
void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
// Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
// ... , (ARn, CRS).
// If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
// and later find that AR1's distance from the CRS is longer, we discard
// AR2 and start tracking AR1. We put the discarded ids in this set, in order
// to skip processing of short paths when we encounter the other ARs that
// have the same id as AR2.
absl::flat_hash_set<int64> discarded_ar_ids;
for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
auto maybe_crs = MatchesArCrsPattern(instruction); auto maybe_pair = MatchesArCrsPattern(instruction);
if (maybe_crs) { if (maybe_pair) {
auto crs = *maybe_crs; auto pair = *maybe_pair;
int64 ar_id = *(instruction->all_reduce_id()); int64 ar_id = *(instruction->all_reduce_id());
if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) { if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
all_reduce_map_[ar_id].push_back(instruction); continue;
crs_reserved_map_[crs] = ar_id; }
auto it = crs_reserved_map_.find(pair.crs);
if (it != crs_reserved_map_.end()) {
auto prev_ar_id = it->second;
// Since there is another AR paired with CRS,
// all_reduce_map_[prev_ar_id] should exist, but
// all_reduce_map_[ar_id] shouldn't.
CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
CHECK_NE(prev_ar_id, ar_id);
auto prev_pair = all_reduce_map_[prev_ar_id].back();
int64 prev_distance = prev_pair.distance;
if (prev_distance < pair.distance) {
// The current AR's distance to CRS is longer than the previously
// tracked AR, so we discard the previous AR.
all_reduce_map_.erase(prev_ar_id);
discarded_ar_ids.insert(prev_ar_id);
all_reduce_map_[ar_id].push_back(pair);
crs_reserved_map_[pair.crs] = ar_id;
} else {
// Discard the current AR id because we are keeping the previously
// tracked AR.
discarded_ar_ids.insert(ar_id);
}
} else {
if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
int64 prev_distance = all_reduce_map_[ar_id].back().distance;
CHECK_EQ(prev_distance, pair.distance)
<< "All ARs with the same AR ID must have the same distance "
"from the corresponding CRSs. Found: "
<< prev_distance << " and " << pair.distance;
}
all_reduce_map_[ar_id].push_back(pair);
crs_reserved_map_[pair.crs] = ar_id;
} }
} }
} }
@ -253,11 +291,11 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
for (auto it : all_reduce_map_) { for (auto it : all_reduce_map_) {
auto all_reduce_id = it.first; auto all_reduce_id = it.first;
auto instruction_vec = it.second; auto pairs_vec = it.second;
CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
auto instr_0 = instruction_vec[0]; auto instr_0 = pairs_vec[0].ar;
for (int i = 1; i < instruction_vec.size(); ++i) { for (int i = 1; i < pairs_vec.size(); ++i) {
auto instr_i = instruction_vec[i]; auto instr_i = pairs_vec[i].ar;
auto next_0 = instr_0->users()[0]; auto next_0 = instr_0->users()[0];
auto next_i = instr_i->users()[0]; auto next_i = instr_i->users()[0];
absl::flat_hash_map<int64, int64> visited_pairs; absl::flat_hash_map<int64, int64> visited_pairs;
@ -281,8 +319,9 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
return false; return false;
} }
for (auto it : all_reduce_map_) { for (auto it : all_reduce_map_) {
auto instruction_vec = it.second; auto pairs_vec = it.second;
for (auto all_reduce : instruction_vec) { for (auto pair : pairs_vec) {
auto all_reduce = pair.ar;
auto parent_computation = all_reduce->parent(); auto parent_computation = all_reduce->parent();
auto all_reduce_id = all_reduce->all_reduce_id(); auto all_reduce_id = all_reduce->all_reduce_id();
auto prev = all_reduce->mutable_operand(0); auto prev = all_reduce->mutable_operand(0);
@ -303,16 +342,23 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
? next->operands()[1] ? next->operands()[1]
: next->operands()[0]; : next->operands()[0];
// To move the AR past the addition/subtraction, we need to divide // To move the AR past the addition/subtraction, we need to divide
// other_operand by the number of spatial partitions. // other_operand by the number of spatial partitions, except if
// other_operand is a cross-module AR, which can be eliminated.
if (other_operand->IsCrossModuleAllReduce() &&
other_operand->user_count() == 1) {
TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
other_operand->mutable_operand(0)));
} else {
auto shape = other_operand->shape(); auto shape = other_operand->shape();
Literal lit(shape); Literal lit(shape);
lit.PopulateWithValue<float>(num_spatial_partitions_); lit.PopulateWithValue<float>(num_spatial_partitions_);
auto divisor = parent_computation->AddInstruction( auto divisor = parent_computation->AddInstruction(
HloInstruction::CreateConstant(lit.Clone())); HloInstruction::CreateConstant(lit.Clone()));
auto division = auto division = parent_computation->AddInstruction(
parent_computation->AddInstruction(HloInstruction::CreateBinary( HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
shape, HloOpcode::kDivide, other_operand, divisor)); other_operand, divisor));
TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
}
break; break;
} }
default: default:

View File

@ -26,11 +26,47 @@ limitations under the License.
namespace xla { namespace xla {
// When the HLO graph contains a cross-module AllReduce, followed by some simple // When the HLO graph contains a cross-module AllReduce, followed by some simple
// linear operations, followed by a cross-replica AllReduce, we can combine the // linear operations, followed by a cross-replica AllReduce (also known as
// CMAR and the CRAR, to use an efficient AllReduce implementation that fully // cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an
// utilizes the interconnect bandwidth. // efficient AllReduce implementation that fully utilizes the interconnect
// bandwidth.
// Such sequences appear in spatially partitioned models. // Such sequences appear in spatially partitioned models.
// This pass must run right after spatial partitioning. // This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by all_reduce_id. They must all be rewritten.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
// other operand by the number of spatial partitions.
// 5) Turn the CRAR into an all-core AllReduce.
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
// Y
// |
// X CMAR_2 Z
// | \ /
// CMAR_1 +
// \ /
// +
// |
// CRAR
//
// gets rewritten to:
//
// Z num_partitions
// \ /
// Y div
// \ /
// X +
// \ /
// +
// |
// all-core AR
//
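A small numeric sanity check of the rewrite sketched above (hypothetical values; it assumes, as the pass itself verifies, that the operand added to the CMAR output is the same on every partition of a replica):

#include <cassert>
#include <cmath>
#include <iostream>

int main() {
  // Hypothetical grid: R = 2 replicas, P = 2 spatial partitions.
  // x[r][p] is the per-core operand of the cross-module AllReduce (CMAR);
  // y[r] is the other addend, identical across the partitions of a replica.
  const int R = 2, P = 2;
  const double x[R][P] = {{1.0, 2.0}, {3.0, 4.0}};
  const double y[R] = {10.0, 20.0};

  // Original graph: CRS( CMAR(x) + y ).
  double cmar[R];  // CMAR sums x across the partitions of each replica.
  for (int r = 0; r < R; ++r) cmar[r] = x[r][0] + x[r][1];
  double original = 0;  // CRS then sums across replicas (same on every core).
  for (int r = 0; r < R; ++r) original += cmar[r] + y[r];

  // Rewritten graph: all-core AR( x + y / P ), i.e. a sum over every core.
  double rewritten = 0;
  for (int r = 0; r < R; ++r)
    for (int p = 0; p < P; ++p) rewritten += x[r][p] + y[r] / P;

  assert(std::abs(original - rewritten) < 1e-12);
  std::cout << "original = " << original << ", rewritten = " << rewritten
            << '\n';
  return 0;
}
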
class ArCrsCombiner : public HloModulePass { class ArCrsCombiner : public HloModulePass {
public: public:
ArCrsCombiner(int num_spatial_partitions) ArCrsCombiner(int num_spatial_partitions)
@ -43,6 +79,28 @@ class ArCrsCombiner : public HloModulePass {
HloInstruction* i2); HloInstruction* i2);
private: private:
// We use this struct because multiple ARs could be paired with the same CRS.
// In this case, we want to select the AR that is furthest from the CRS,
// because it makes it easier to eliminate all ARs during RewriteGraph.
struct ArCrsPair {
HloInstruction* ar;
HloInstruction* crs;
// The length of the path from AR to CRS in the HLO graph.
int64 distance;
ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
int64 dist)
: ar(all_reduce), crs(cross_replica_sum), distance(dist) {}
string ToString() {
return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(),
", distance: ", distance, ")");
}
};
absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
HloInstruction* instruction);
// If the passed instruction is a while parameter, and the while body is only // If the passed instruction is a while parameter, and the while body is only
// called by a single while instruction, return the while instruction. // called by a single while instruction, return the while instruction.
absl::optional<HloInstruction*> WhileFromBodyParameter( absl::optional<HloInstruction*> WhileFromBodyParameter(
@ -80,8 +138,8 @@ class ArCrsCombiner : public HloModulePass {
int num_spatial_partitions_; int num_spatial_partitions_;
// Map from all-reduce ids to the all reduce instructions. // Map from all-reduce ids to the AR/CRS pairs.
absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_; absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;
// Map from a CRS instruction to the all-reduce ID of the AR paired with the // Map from a CRS instruction to the all-reduce ID of the AR paired with the
// CRS. Sometimes, several ARs in the code could be paired with the same CRS. // CRS. Sometimes, several ARs in the code could be paired with the same CRS.

View File

@ -1005,11 +1005,11 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
op::Tuple(op::AllReduce(op::Add( op::Tuple(op::AllReduce(op::Add(
op::Add(op::Parameter(), op::Add(op::Parameter(),
op::Divide(op::Constant(), op::Constant())), op::Divide(op::Constant(), op::Constant())),
op::Divide(op::AllReduce(), op::Constant()))), op::Parameter())),
op::AllReduce(op::Add( op::AllReduce(op::Add(
op::Add(op::Parameter(), op::Add(op::Parameter(),
op::Divide(op::Constant(), op::Constant())), op::Divide(op::Constant(), op::Constant())),
op::Divide(op::AllReduce(), op::Constant()))))); op::Parameter()))));
auto crs_after = auto crs_after =
module->entry_computation()->root_instruction()->operands()[0]; module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups(); auto replica_groups_after = crs_after->replica_groups();
@ -1093,15 +1093,17 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
ArCrsCombiner combiner(2); ArCrsCombiner combiner(2);
auto changed = combiner.Run(module.get()).ValueOrDie(); auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed); EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(), EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::AllReduce(op::Add( op::Tuple(op::AllReduce(op::Add(
op::Parameter(), op::Parameter(),
op::Divide(op::Add(op::AllReduce(), op::Constant()), op::Add(op::Parameter(),
op::Constant()))), op::Divide(op::Constant(), op::Constant())))),
op::AllReduce(op::Add( op::AllReduce(op::Add(
op::Parameter(), op::Parameter(),
op::Divide(op::Add(op::AllReduce(), op::Constant()), op::Add(op::Parameter(),
op::Constant()))))); op::Divide(op::Constant(), op::Constant()))))));
auto crs_after = auto crs_after =
module->entry_computation()->root_instruction()->operands()[0]; module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups(); auto replica_groups_after = crs_after->replica_groups();

View File

@ -286,7 +286,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, auto* sort,
MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
{key, value}, 0, &builder, module.get())); {key, value}, 0, /*is_stable=*/false, &builder,
module.get()));
HloInstruction* gte = builder.AddInstruction( HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
@ -314,7 +315,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) {
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, auto* sort,
MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}),
{key, value}, 0, &builder, module.get())); {key, value}, 0, /*is_stable=*/false, &builder,
module.get()));
auto computation = module->AddEntryComputation(builder.Build()); auto computation = module->AddEntryComputation(builder.Build());

View File

@ -673,9 +673,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
if (embed_ir_in_executable) { if (embed_ir_in_executable) {
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
} }
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
// JIT compile the LLVM IR module to in-memory machine code. // JIT compile the LLVM IR module to in-memory machine code.
jit->AddModule(std::move(llvm_module)); jit->AddModule(std::move(llvm_module));

View File

@ -963,8 +963,8 @@ Status EmitBatchDotOperation(
KernelSupportLibrary ksl(b); KernelSupportLibrary ksl(b);
return ksl.ForWithStatus( return ksl.ForWithStatus(
"bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
[&](llvm::Value* indvar) { /*step=*/1, [&](llvm::Value* indvar) {
DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers();
adjusted_dim_numbers.clear_lhs_batch_dimensions(); adjusted_dim_numbers.clear_lhs_batch_dimensions();
adjusted_dim_numbers.clear_rhs_batch_dimensions(); adjusted_dim_numbers.clear_rhs_batch_dimensions();

View File

@ -583,7 +583,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
b_.getVoidTy(), b_.getVoidTy(),
{b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
b_.getInt32Ty()->getPointerTo(), b_.getInt8PtrTy(), b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(),
b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, b_.getInt64Ty()->getPointerTo(), less_than_function->getType()},
/*isVarArg=*/false); /*isVarArg=*/false);
auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>( auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
@ -616,8 +616,8 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
{b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
b_.getInt64(lower_dimensions), values, b_.getInt64(lower_dimensions), values,
b_.getInt32(sort->operand_count()), sizes, b_.getInt32(sort->operand_count()), sizes,
GetExecutableRunOptionsArgument(), GetProfileCountersArgument(), b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(),
less_than_function}); GetProfileCountersArgument(), less_than_function});
if (sort->values_count() > 0) { if (sort->values_count() > 0) {
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,

View File

@ -32,8 +32,8 @@ using tensorflow::int64;
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
int64 a, int64 b, int64 c, char** values, int32 values_count, int64 a, int64 b, int64 c, char** values, int32 values_count,
int32* values_primitive_type_size_in_bytes, char* run_options, int32* values_primitive_type_size_in_bytes, bool is_stable,
int64* prof_counters, char* run_options, int64* prof_counters,
void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) {
// 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT
// code, so msan can't tell they are initialized. // code, so msan can't tell they are initialized.
@ -69,9 +69,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
int64 base_offset = int64 base_offset =
index % sort_dimension_offset + index % sort_dimension_offset +
(index - index % sort_dimension_offset) * sort_dimension_elements; (index - index % sort_dimension_offset) * sort_dimension_elements;
std::stable_sort( auto compare_function = [&](int64 a, int64 b) -> bool {
indices.get(), indices.get() + sort_dimension_elements,
[&](int64 a, int64 b) -> bool {
int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) *
values_primitive_type_size_in_bytes[0]; values_primitive_type_size_in_bytes[0];
int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) *
@ -84,7 +82,14 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort(
less_than(&result, run_options, comparison_values.get(), nullptr, less_than(&result, run_options, comparison_values.get(), nullptr,
prof_counters); prof_counters);
return result != 0u; return result != 0u;
}); };
if (is_stable) {
std::stable_sort(indices.get(), indices.get() + sort_dimension_elements,
compare_function);
} else {
std::sort(indices.get(), indices.get() + sort_dimension_elements,
compare_function);
}
// Reorder the values according to the order defined by 'indices'. // Reorder the values according to the order defined by 'indices'.
for (int32 idx = 0; idx < values_count; ++idx) { for (int32 idx = 0; idx < values_count; ++idx) {

View File

@ -22,15 +22,14 @@ limitations under the License.
extern "C" { extern "C" {
// Each entry in 'values' represents a 3-dimensional shape with dimensions // Each entry in 'values' represents a 3-dimensional shape with dimensions
// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending // [a, b, c]. The 'b' dimension of each shape is sorted into ascending order
// order according to the results of comparisons using the provided 'less_than' // according to the results of comparisons using the provided 'less_than'
// function. 'values_count' must be > 0 and specifies the number of entries in // function. 'values_count' must be > 0 and specifies the number of entries in
// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive
// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]'
// bytes. The elements in each 'values' shape are reordered in the same way // bytes. 'is_stable' specifies whether the sorting should be stable.
// according to the comparisons using the first shape. 'run_options' and // 'run_options' and 'prof_counters' are passed through to the less-than
// 'prof_counters' are passed through to the less-than function, which expects // function, which expects the following arguments:
// the following arguments:
// - pointer to the return value buffer (char*) // - pointer to the return value buffer (char*)
// - xla::ExecutableRunOptions = 'run_options' (char*) // - xla::ExecutableRunOptions = 'run_options' (char*)
// - pointers to the parameter buffers (char**) // - pointers to the parameter buffers (char**)
@ -39,8 +38,8 @@ extern "C" {
extern void __xla_cpu_runtime_KeyValueSort( extern void __xla_cpu_runtime_KeyValueSort(
tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
char** values, tensorflow::int32 values_count, char** values, tensorflow::int32 values_count,
tensorflow::int32* values_primitive_type_size_in_bytes, char* run_options, tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable,
tensorflow::int64* prof_counters, char* run_options, tensorflow::int64* prof_counters,
void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); void (*less_than)(char*, char*, char**, char**, tensorflow::int64*));
} }
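As a plain C++ illustration of the layout described above (hypothetical sizes, not the runtime implementation): each slice along the 'b' dimension of a row-major [a, b, c] buffer is sorted through an index permutation, using std::stable_sort or std::sort depending on 'is_stable'.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical [a=1, b=4, c=2] buffer stored row-major: element (i, j, k)
  // lives at (i * b + j) * c + k. We sort along 'b' for every (i, k) pair.
  const int64_t a = 1, b = 4, c = 2;
  std::vector<float> keys = {3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 0.f, 0.f};
  const bool is_stable = false;
  for (int64_t i = 0; i < a; ++i) {
    for (int64_t k = 0; k < c; ++k) {
      // Sort an index permutation first, similar to the runtime above.
      std::vector<int64_t> indices(b);
      std::iota(indices.begin(), indices.end(), 0);
      auto less_than = [&](int64_t x, int64_t y) {
        return keys[(i * b + x) * c + k] < keys[(i * b + y) * c + k];
      };
      if (is_stable) {
        std::stable_sort(indices.begin(), indices.end(), less_than);
      } else {
        std::sort(indices.begin(), indices.end(), less_than);
      }
      // Reorder the slice according to 'indices'.
      std::vector<float> sorted(b);
      for (int64_t j = 0; j < b; ++j) {
        sorted[j] = keys[(i * b + indices[j]) * c + k];
      }
      for (int64_t j = 0; j < b; ++j) {
        keys[(i * b + j) * c + k] = sorted[j];
      }
    }
  }
  for (float v : keys) std::cout << v << ' ';  // prints 0 0 1 1 2 2 3 3
  std::cout << '\n';
  return 0;
}
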

View File

@ -938,6 +938,53 @@ void TiledSmallGemmEmitter::EmitTiledGemm(
}); });
} }
llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) {
llvm::Type* type =
llvm::cast<llvm::PointerType>(pointer_type)->getElementType();
while (auto* array_type = llvm::dyn_cast<llvm::ArrayType>(type)) {
type = array_type->getElementType();
}
return type->getPointerTo();
}
struct GemvBuffersWithCanonicalType {
llvm::Value* lhs_canonicalized;
llvm::Value* rhs_canonicalized;
llvm::Value* addend_canonicalized;
llvm::Value* result_canonicalized;
};
GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType(
llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result, llvm::IRBuilder<>* b) {
// We characterize a GEMV operation via M and K, since N is implicitly 1.
// This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented
// by the same GEMV that multiplies [5,6] with [1,6]. However, the
// `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial
// sense -- the in-memory representations are the same) since they're computed
// from the `xla::Shape`s. Since we want to be able to call the same
// `llvm::Function` for the two GEMVs, we canonicalize the types of the GEMV
// inputs here into the same type.
GemvBuffersWithCanonicalType buffers_with_canonical_type;
llvm::Type* lhs_type = lhs->getType();
llvm::Type* rhs_type = rhs->getType();
llvm::Type* addend_type = addend ? addend->getType() : nullptr;
llvm::Type* result_type = result->getType();
buffers_with_canonical_type.lhs_canonicalized =
b->CreateBitCast(lhs, GetPointerToElementType(lhs_type));
buffers_with_canonical_type.rhs_canonicalized =
b->CreateBitCast(rhs, GetPointerToElementType(rhs_type));
buffers_with_canonical_type.addend_canonicalized =
addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type))
: nullptr;
buffers_with_canonical_type.result_canonicalized =
b->CreateBitCast(result, GetPointerToElementType(result_type));
return buffers_with_canonical_type;
}
} // namespace } // namespace
void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
@ -950,11 +997,17 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
/*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
GemvBuffersWithCanonicalType canonical_inputs =
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
KernelSupportLibrary::EmitAndCallOutlinedKernel( KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math, /*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
rhs, addend, result, canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
[&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, canonical_inputs.addend_canonicalized,
canonical_inputs.result_canonicalized,
[&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* addend,
llvm::Value* result) { llvm::Value* result) {
RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
result, b); result, b);
@ -972,11 +1025,17 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
/*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr);
GemvBuffersWithCanonicalType canonical_inputs =
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
KernelSupportLibrary::EmitAndCallOutlinedKernel( KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math, /*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
rhs, addend, result, canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
[&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, canonical_inputs.addend_canonicalized,
canonical_inputs.result_canonicalized,
[&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* addend,
llvm::Value* result) { llvm::Value* result) {
ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend,
result, b); result, b);

View File

@ -1367,26 +1367,69 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_);
llvm::Type* raw_value_ty = raw_value->getType(); llvm::Type* raw_value_ty = raw_value->getType();
// Convert raw integer to float in range [0, 1) if the element is a float. // If we're generating a floating-point value, convert the raw integer R (i.e.
// `raw_value`) to a float in the range [0, 1).
//
// The basic approach is to choose a significand and exponent such that the
// significand is uniformly distributed and the exponent is distributed, well,
// exponentially (it's more likely to be close to 0 than far from 0).
//
// An easy way to do this is to say that the significand is the first S bits
// of R, and the exponent is determined by the number of trailing zeroes in R,
// exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1;
// this way the largest value we can return is 1.999... * 2^-1 = 1-ε.)
//
// This results in a small bias. Namely, if R has enough trailing zeroes, the
// significand and exponent will "overlap". As a concrete example, consider
//
// 20 X's 12 zeroes
// R = 0bXXXXXXXXXXXXXXXXXXXX000000000000
//
// Here the exponent is 2^-13 because R has 12 trailing zeroes. The
// significand is made up of the first 23 most-significant bits of R, which we
// observe contain 3 zeroes. This is biased because any random value with
// exponent 2^-12 will have a significand which ends in `000`.
//
// For f32s, this problem occurs only when there are more than 32-23 = 9
// trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the
// probability of a large bias (i.e. many trailing 0s in the significand) is
// exponentially low. So we deem this acceptable.
llvm::Value* elem_value = raw_value; llvm::Value* elem_value = raw_value;
if (elem_ir_ty->isFloatingPointTy()) { if (elem_ir_ty->isFloatingPointTy()) {
unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics();
CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); const int bits = raw_value_ty->getPrimitiveSizeInBits();
// Perform the division using the float type with the same number of bits CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics));
// as the raw value to avoid overflow.
if (raw_value_size_in_bits == 32) {
elem_value = UIToFP(elem_value, b_->getFloatTy());
elem_value = FDiv(elem_value,
llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
} else {
elem_value = UIToFP(elem_value, b_->getDoubleTy());
elem_value = FDiv(
elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
}
if (elem_ir_ty != elem_value->getType()) { // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the
elem_value = FPTrunc(elem_value, elem_ir_ty); // implicit "1." at the beginning of the significand.
} const int significand_bits =
llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1;
llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()},
{raw_value->getType()}, b_);
llvm::Value* significand = LShr(raw_value, bits - significand_bits);
// Exponent bias is -127 for f32, meaning that if the exponent is E and the
// significand is S, then the value of the number is 2^(E - 127) * (1.S).
//
// We want cttz == 0 to correspond to 2^-1, so our exponent is computed as
// E = 126 - cttz.
//
// For f64, this is all the same, except the bias is -1023.
//
// In IEEE floating point, the absolute value of the exponent bias equals
// the value of the largest possible exponent.
const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics);
llvm::Value* exponent =
Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz);
// Now just slot everything into place! The `Trunc` is here because
// raw_value may be larger than our float destination.
elem_value =
BitCast(Trunc(Or(Shl(exponent, significand_bits), significand),
b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())),
elem_ir_ty);
} }
// Convert the value for the requested distribution. // Convert the value for the requested distribution.

View File

@ -216,8 +216,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
const HloInstruction* hlo, const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator); const HloToElementGeneratorMap& operand_to_generator);
// Converts the raw value generated by a random number generation algorithm // Converts the raw value generated by a random number generation algorithm
// to the distribution requested by the RNG HloInstruction. // to the distribution requested by the RNG HloInstruction.
//
// Precondition: raw_value has at least as many bits as hlo's element type.
StatusOr<llvm::Value*> ConvertValueForDistribution( StatusOr<llvm::Value*> ConvertValueForDistribution(
const HloInstruction* hlo, const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,

View File

@ -765,6 +765,7 @@ cc_library(
"//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:sort_simplifier",
"//tensorflow/compiler/xla/service:stable_sort_expander",
"//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_constant_sinking",

View File

@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@ -195,6 +196,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddPass<ConvolutionGroupConverter>( pipeline.AddPass<ConvolutionGroupConverter>(
cost_model, cost_model,
/*convert_batch_groups_only=*/true); /*convert_batch_groups_only=*/true);
// Expand the sort op to support stable sorting if required.
pipeline.AddPass<StableSortExpander>();
// Convert BF16 operations to F32 operations so that the GPU backend can // Convert BF16 operations to F32 operations so that the GPU backend can
// support BF16 operations without directly implementing a BF16 lowering for // support BF16 operations without directly implementing a BF16 lowering for
// most ops. // most ops.
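For context, a generic way to recover a stable order from an unstable sort (a sketch of the general idea only, not necessarily this expander's exact rewrite) is to pair each key with its original position and break ties on that position:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Keys with duplicates; 'iota' records each key's original position.
  std::vector<float> keys = {2.f, 1.f, 2.f, 1.f};
  std::vector<int64_t> iota = {0, 1, 2, 3};
  std::vector<std::size_t> perm = {0, 1, 2, 3};
  // An unstable sort whose comparator falls back to the original position
  // produces the same order a stable sort would.
  std::sort(perm.begin(), perm.end(), [&](std::size_t x, std::size_t y) {
    if (keys[x] != keys[y]) return keys[x] < keys[y];
    return iota[x] < iota[y];  // tie-break on original position
  });
  for (std::size_t p : perm) {
    std::cout << keys[p] << '(' << iota[p] << ") ";
  }
  std::cout << '\n';  // prints 1(1) 1(3) 2(0) 2(2)
  return 0;
}
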

View File

@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true; option cc_enable_arenas = true;
// Serialization of HloInstruction. // Serialization of HloInstruction.
// Next ID: 60 // Next ID: 61
message HloInstructionProto { message HloInstructionProto {
reserved 10; reserved 10;
reserved "parameter_name"; reserved "parameter_name";
@ -175,6 +175,9 @@ message HloInstructionProto {
// partners. // partners.
bool is_host_transfer = 47; bool is_host_transfer = 47;
// Whether this Sort instruction should be stable.
bool is_stable = 60;
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
// Precision configuration for the instruction. Has backend-specific meaning. // Precision configuration for the instruction. Has backend-specific meaning.

View File

@ -275,7 +275,7 @@ StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
StatusOr<HloInstruction*> MakeSortHlo( StatusOr<HloInstruction*> MakeSortHlo(
const Shape& sort_shape, absl::Span<HloInstruction* const> operands, const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
int64 dimension_to_sort, HloComputation::Builder* builder, int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
HloModule* module) { HloModule* module) {
CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
HloComputation* compare_computation; HloComputation* compare_computation;
@ -293,7 +293,7 @@ StatusOr<HloInstruction*> MakeSortHlo(
compare_computation = compare_computation =
module->DeepCloneComputation(new_module->entry_computation(), &context); module->DeepCloneComputation(new_module->entry_computation(), &context);
return builder->AddInstruction(HloInstruction::CreateSort( return builder->AddInstruction(HloInstruction::CreateSort(
sort_shape, dimension_to_sort, operands, compare_computation)); sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
} }
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) { StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {

View File

@ -126,10 +126,10 @@ StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
// Creates a Sort HLO instruction and adds it to the computation containing the // Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a // operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending // default compare sub-computation which sorts the first operand into ascending
// order. // order. 'is_stable' specifies whether the sorting should be stable.
StatusOr<HloInstruction*> MakeSortHlo( StatusOr<HloInstruction*> MakeSortHlo(
const Shape& sort_shape, absl::Span<HloInstruction* const> operands, const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
int64 dimension_to_sort, HloComputation::Builder* builder, int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
HloModule* module); HloModule* module);
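A minimal usage sketch of the new five-argument signature, assuming a pre-existing HloComputation::Builder named `builder`, an HloModule* named `module`, and an enclosing function that returns a Status; it mirrors the updated MakeSortHlo call sites in the dataflow-analysis tests later in this diff.
Shape keys_shape = ShapeUtil::MakeShape(F32, {64});
HloInstruction* keys = builder.AddInstruction(
    HloInstruction::CreateParameter(0, keys_shape, "keys"));
TF_ASSIGN_OR_RETURN(
    HloInstruction * sort,
    MakeSortHlo(keys_shape, {keys}, /*dimension_to_sort=*/0,
                /*is_stable=*/true, &builder, module));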
// Creates an R1 Constant HLO instruction of the given PrimitiveType with the // Creates an R1 Constant HLO instruction of the given PrimitiveType with the

View File

@ -2363,7 +2363,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto keys = builder.AddInstruction( auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys")); HloInstruction::CreateParameter(0, keys_shape, "keys"));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, MakeSortHlo(keys_shape, {keys}, -1, &builder, module_.get())); auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false,
&builder, module_.get()));
computation_ = module_->AddEntryComputation(builder.Build()); computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis(); RunAnalysis();
@ -2385,7 +2386,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, auto* sort,
MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
{keys, values}, 0, &builder, module_.get())); {keys, values}, 0, /*is_stable=*/false, &builder,
module_.get()));
computation_ = module_->AddEntryComputation(builder.Build()); computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis(); RunAnalysis();

View File

@ -2670,12 +2670,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& high = const Literal& high =
parent_->GetEvaluatedLiteralFor(random->operand(1)); parent_->GetEvaluatedLiteralFor(random->operand(1));
std::uniform_real_distribution<NativeT> generator( // std::uniform_real_distribution(a, b) can sometimes return a value
low.Get<NativeT>({}), high.Get<NativeT>({})); // equal to b. Unclear if this is a spec bug or an implementation bug
// or WAI [0] [1] [2]. Anyway for our purposes we want a half-open
// interval, so we have to re-sample if we get `b` out.
//
// [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
// [1] https://bugs.llvm.org/show_bug.cgi?id=18767
// [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524
auto low_val = low.Get<NativeT>({});
auto high_val = high.Get<NativeT>({});
std::uniform_real_distribution<NativeT> generator(low_val, high_val);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) { result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
return generator(parent_->engine_); while (true) {
NativeT v = generator(parent_->engine_);
if (v != high_val) {
return v;
}
}
})); }));
break; break;
} }
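The re-sampling workaround described in the comment above can be demonstrated outside XLA with a few lines of standard C++. This is a standalone sketch of the same idea, not code from this change.
#include <iostream>
#include <random>

int main() {
  std::mt19937_64 engine(42);
  const double low = 0.0, high = 1.0;
  std::uniform_real_distribution<double> dist(low, high);
  for (int i = 0; i < 4; ++i) {
    double v;
    do {
      v = dist(engine);   // may (rarely) return exactly `high`
    } while (v == high);  // re-sample to get the half-open interval [low, high)
    std::cout << v << "\n";
  }
  return 0;
}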

View File

@ -214,7 +214,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.called_computation_ids_size(); << proto.called_computation_ids_size();
auto sort_operands = all_operands(); auto sort_operands = all_operands();
instruction = CreateSort(shape, proto.dimensions(0), all_operands(), instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
computations(0)); computations(0), proto.is_stable());
break; break;
} }
case HloOpcode::kTranspose: case HloOpcode::kTranspose:
@ -1170,9 +1170,10 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, const Shape& shape, int64 dimension,
absl::Span<HloInstruction* const> operands, HloComputation* compare) { absl::Span<HloInstruction* const> operands, HloComputation* compare,
bool is_stable) {
return absl::make_unique<HloSortInstruction>(shape, dimension, operands, return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
compare); compare, is_stable);
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(

View File

@ -384,6 +384,14 @@ class HloInstruction {
// Creates a random number generation instruction that fills a shape with // Creates a random number generation instruction that fills a shape with
// random numbers from a given distribution. // random numbers from a given distribution.
//
// The parameters to the instruction are interpreted as follows:
//
// - If `distribution` is RNG_UNIFORM, generates a number in range
// [param0, param1).
//
// - If `distribution` is RNG_NORMAL, generates a normally-distributed value
// with mean `param0` and standard deviation `param1`.
static std::unique_ptr<HloInstruction> CreateRng( static std::unique_ptr<HloInstruction> CreateRng(
const Shape& shape, RandomDistribution distribution, const Shape& shape, RandomDistribution distribution,
absl::Span<HloInstruction* const> parameters); absl::Span<HloInstruction* const> parameters);
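A hedged sketch of constructing an RNG_UNIFORM instruction with the semantics documented above; `builder` is an assumed HloComputation::Builder and the output shape is a placeholder.
auto* low = builder.AddInstruction(
    HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto* high = builder.AddInstruction(
    HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
// Fills an f32[16] buffer with samples drawn from the half-open range [0, 1).
builder.AddInstruction(HloInstruction::CreateRng(
    ShapeUtil::MakeShape(F32, {16}), RandomDistribution::RNG_UNIFORM,
    {low, high}));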
@ -678,10 +686,11 @@ class HloInstruction {
// comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
// where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
// specific index positions which should be compared, and should return a // specific index positions which should be compared, and should return a
// PRED. // PRED. 'is_stable' specifies whether stable sorting is required.
static std::unique_ptr<HloInstruction> CreateSort( static std::unique_ptr<HloInstruction> CreateSort(
const Shape& shape, int64 dimension, const Shape& shape, int64 dimension,
absl::Span<HloInstruction* const> operands, HloComputation* compare); absl::Span<HloInstruction* const> operands, HloComputation* compare,
bool is_stable);
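For illustration, a hedged sketch of building a two-parameter less-than comparator and passing it to the new five-argument CreateSort; `module` and `sort_shape` are assumed to exist, and every call used here also appears elsewhere in this diff.
Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
HloComputation::Builder cmp_builder("compare");
auto* lhs = cmp_builder.AddInstruction(
    HloInstruction::CreateParameter(0, scalar_f32, "p.0.lhs"));
auto* rhs = cmp_builder.AddInstruction(
    HloInstruction::CreateParameter(1, scalar_f32, "p.0.rhs"));
cmp_builder.AddInstruction(
    HloInstruction::CreateBinary(scalar_pred, HloOpcode::kLt, lhs, rhs));
HloComputation* compare = module->AddEmbeddedComputation(cmp_builder.Build());

HloComputation::Builder builder("sort_builder");
auto* keys = builder.AddInstruction(
    HloInstruction::CreateParameter(0, sort_shape, "keys"));
builder.AddInstruction(HloInstruction::CreateSort(
    sort_shape, /*dimension=*/0, {keys}, compare, /*is_stable=*/true));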
// Creates a while instruction, given a condition computation, a body // Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For // computation, and the initial value for the input of the computations. For
@ -1286,6 +1295,9 @@ class HloInstruction {
backend_config_ = std::move(config_str); backend_config_ = std::move(config_str);
} }
bool is_default_config() const { return is_default_config_; }
void set_default_config() { is_default_config_ = true; }
// Returns a string representation of a proto in the format used by // Returns a string representation of a proto in the format used by
// raw_backend_config_string. // raw_backend_config_string.
// //
@ -1734,6 +1746,10 @@ class HloInstruction {
// HLO. See the documentation on backend_config(). // HLO. See the documentation on backend_config().
string backend_config_; string backend_config_;
// This field is assigned to true when backend_config_ is assigned to
// a default configuration.
bool is_default_config_ = false;
// String identifier for instruction. // String identifier for instruction.
string name_; string name_;

View File

@ -659,8 +659,11 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
HloSortInstruction::HloSortInstruction( HloSortInstruction::HloSortInstruction(
const Shape& shape, int64 dimension, const Shape& shape, int64 dimension,
absl::Span<HloInstruction* const> operands, HloComputation* compare) absl::Span<HloInstruction* const> operands, HloComputation* compare,
: HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { bool is_stable)
: HloInstruction(HloOpcode::kSort, shape),
dimensions_({dimension}),
is_stable_(is_stable) {
for (auto* value : operands) { for (auto* value : operands) {
AppendOperand(value); AppendOperand(value);
} }
@ -672,12 +675,18 @@ HloInstructionProto HloSortInstruction::ToProto() const {
for (int64 dimension : dimensions_) { for (int64 dimension : dimensions_) {
proto.add_dimensions(dimension); proto.add_dimensions(dimension);
} }
proto.set_is_stable(is_stable());
return proto; return proto;
} }
std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl( std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const { const HloPrintOptions& options) const {
return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; std::vector<string> attrs;
attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}"));
if (is_stable()) {
attrs.push_back("is_stable=true");
}
return attrs;
} }
bool HloSortInstruction::IdenticalSlowPath( bool HloSortInstruction::IdenticalSlowPath(
@ -688,14 +697,17 @@ bool HloSortInstruction::IdenticalSlowPath(
if (dimensions() != casted_other.dimensions()) { if (dimensions() != casted_other.dimensions()) {
return false; return false;
} }
if (is_stable() != casted_other.is_stable()) {
return false;
}
return eq_computations(to_apply(), other.to_apply()); return eq_computations(to_apply(), other.to_apply());
} }
std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl( std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands, const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const { HloCloneContext* context) const {
return absl::make_unique<HloSortInstruction>(shape, dimensions(0), return absl::make_unique<HloSortInstruction>(
new_operands, to_apply()); shape, dimensions(0), new_operands, to_apply(), is_stable());
} }
HloTransposeInstruction::HloTransposeInstruction( HloTransposeInstruction::HloTransposeInstruction(

View File

@ -447,7 +447,7 @@ class HloSortInstruction : public HloInstruction {
public: public:
explicit HloSortInstruction(const Shape& shape, int64 dimension, explicit HloSortInstruction(const Shape& shape, int64 dimension,
absl::Span<HloInstruction* const> operands, absl::Span<HloInstruction* const> operands,
HloComputation* compare); HloComputation* compare, bool is_stable);
// Returns the dimension sizes or numbers associated with this instruction. // Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; } const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; } int64 dimensions(int64 index) const override { return dimensions()[index]; }
@ -460,6 +460,7 @@ class HloSortInstruction : public HloInstruction {
HloInstruction* mutable_keys() { return mutable_operand(0); } HloInstruction* mutable_keys() { return mutable_operand(0); }
// Returns the number of value operands. // Returns the number of value operands.
int64 values_count() const { return operand_count() - 1; } int64 values_count() const { return operand_count() - 1; }
bool is_stable() const { return is_stable_; }
private: private:
std::vector<string> ExtraAttributesToStringImpl( std::vector<string> ExtraAttributesToStringImpl(
@ -474,6 +475,7 @@ class HloSortInstruction : public HloInstruction {
HloCloneContext* context) const override; HloCloneContext* context) const override;
std::vector<int64> dimensions_; std::vector<int64> dimensions_;
bool is_stable_;
}; };
class HloTransposeInstruction : public HloInstruction { class HloTransposeInstruction : public HloInstruction {

View File

@ -895,6 +895,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<int64>> dimensions; optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions}; &dimensions};
optional<bool> is_stable = false;
attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
optional<HloComputation*> to_apply; optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply}; &to_apply};
@ -902,8 +904,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
dimensions->size() != 1) { dimensions->size() != 1) {
return false; return false;
} }
instruction = builder->AddInstruction(HloInstruction::CreateSort( instruction = builder->AddInstruction(
shape, dimensions->at(0), operands, to_apply.value())); HloInstruction::CreateSort(shape, dimensions->at(0), operands,
to_apply.value(), is_stable.value()));
break; break;
} }
case HloOpcode::kTuple: { case HloOpcode::kTuple: {

View File

@ -1145,6 +1145,24 @@ ENTRY Sort {
ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
} }
)"
},
// Sort (Key) is_stable=true
{
"SortKeyStable",
R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY Sort {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}
)" )"
}, },
// Conditional // Conditional

View File

@ -176,7 +176,7 @@ StatusOr<Literal> HloRunner::Execute(
TransferLiteralsToDevice(arguments)); TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
ExecuteWithDeviceBuffers( ExecuteWithDeviceBuffers(
/*module=*/std::move(executable), /*executable=*/executable.get(),
/*arguments=*/argument_buffers, /*arguments=*/argument_buffers,
/*profile=*/profile)); /*profile=*/profile));
return TransferLiteralFromDevice(result); return TransferLiteralFromDevice(result);
@ -235,7 +235,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
} }
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<Executable> executable, Executable* executable,
const absl::Span<const ShapedBuffer* const> arguments, const absl::Span<const ShapedBuffer* const> arguments,
ExecutionProfile* profile) { ExecutionProfile* profile) {
// Get service run options. // Get service run options.
@ -254,7 +254,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
} }
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<Executable> executable, Executable* executable,
const absl::Span<const ScopedShapedBuffer> arguments, const absl::Span<const ScopedShapedBuffer> arguments,
ExecutionProfile* profile) { ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers; std::vector<const ShapedBuffer*> argument_pointers;

View File

@ -144,13 +144,16 @@ class HloRunner {
const absl::Span<const ScopedShapedBuffer> arguments, const absl::Span<const ScopedShapedBuffer> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// In the following two calls, "executable" is not a unique_ptr to allow
// reuse of the Executable. This call may update the profile information in
// *executable.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers( StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<Executable> executable, Executable* executable,
const absl::Span<const ShapedBuffer* const> arguments, const absl::Span<const ShapedBuffer* const> arguments,
ExecutionProfile* profile = nullptr); ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers( StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<Executable> executable, Executable* executable,
const absl::Span<const ScopedShapedBuffer> arguments, const absl::Span<const ScopedShapedBuffer> arguments,
ExecutionProfile* profile = nullptr); ExecutionProfile* profile = nullptr);
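A hedged sketch of the compile-once, run-many pattern that the raw Executable* overloads enable. `runner`, `module`, and `args` (a vector of Literal pointers) are assumed, CreateExecutable is assumed to be the runner's existing compilation helper, and the enclosing function is assumed to return a Status.
TF_ASSIGN_OR_RETURN(
    std::unique_ptr<Executable> executable,
    runner.CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> buffers,
                    runner.TransferLiteralsToDevice(args));
for (int i = 0; i < 3; ++i) {
  // The same Executable* is reused on every iteration; no ownership transfer.
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer result,
      runner.ExecuteWithDeviceBuffers(executable.get(), buffers));
}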

View File

@ -36,6 +36,9 @@ StatusOr<bool> OpExpanderPass::Run(HloModule* module) {
for (HloInstruction* inst : matching_instructions) { for (HloInstruction* inst : matching_instructions) {
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root,
ExpandInstruction(inst)); ExpandInstruction(inst));
if (expanded_root == nullptr) {
continue;
}
TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root));
} }

View File

@ -33,7 +33,9 @@ class OpExpanderPass : public HloModulePass {
// Returns `true` if `instruction` should be expanded by this pass. // Returns `true` if `instruction` should be expanded by this pass.
virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0;
// Returns a replacement for `instruction`. // Returns a replacement for `instruction`, or nullptr if no replacement is
// needed (e.g. only the to_apply subcomputation of the instruction was
// modified).
virtual StatusOr<HloInstruction*> ExpandInstruction( virtual StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) = 0; HloInstruction* instruction) = 0;
}; };
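A hypothetical subclass sketch (not part of this change) showing the nullptr contract documented above: the pass edits only the matched instruction's subcomputation and returns nullptr so that no graph replacement is attempted.
class RewriteToApplyOnlyPass : public OpExpanderPass {
 public:
  absl::string_view name() const override { return "rewrite-to-apply-only"; }

 private:
  bool InstructionMatchesPattern(HloInstruction* instruction) override {
    return instruction->opcode() == HloOpcode::kSort;
  }
  StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override {
    // Hypothetical: mutate instruction->to_apply() in place here.
    return nullptr;  // Nothing to splice into the graph.
  }
};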

View File

@ -0,0 +1,204 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
#include <limits>
#include <memory>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// Looks for an iota operand that can be used as a tie breaker in the computation.
// If no matching iota operand is found, an iota operand is added to Sort. The
// comparison computation is adjusted to break ties using the values from the
// iota operand.
StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
HloInstruction* instruction) {
auto* sort = Cast<HloSortInstruction>(instruction);
HloComputation* computation = sort->parent();
HloInstruction* expanded_sort = nullptr;
absl::flat_hash_set<int64> used_indices;
int64 iota_index = -1;
for (const HloInstruction* operand : sort->operands()) {
// We can only use the iota operand if it has an iota dimension which is the
// same as the dimension to sort. Also it should have an integral type that
// is large enough for the number of elements in the sort dimension. For
// now, we only allow S32, because we expect to find an S32 iota operand for
// all Sort ops which are created by TopK.
// TODO(b/122298745): Also support other types.
if (operand->opcode() == HloOpcode::kIota &&
Cast<HloIotaInstruction>(operand)->iota_dimension() ==
sort->sort_dimension() &&
operand->shape().element_type() == S32) {
iota_index = sort->operand_index(operand);
break;
}
}
// If there is currently no iota operand which we could use for making the
// sort stable, we will have to add a new such operand.
if (iota_index == -1) {
Shape iota_shape = sort->operand(0)->shape();
// We might need to use S64 if the number of elements in the sort dimension
// is bigger than 2^31 - 1.
// TODO(b/122298745): Handle Sort ops where S32 is too small for the number
// of elements in the sort dimension.
if (iota_shape.dimensions(sort->sort_dimension()) >
std::numeric_limits<int32>::max()) {
return Unimplemented(
"Stable sorting of more than 2^31-1 elements is not implemented");
}
iota_shape.set_element_type(S32);
auto iota = computation->AddInstruction(
HloInstruction::CreateIota(iota_shape, sort->sort_dimension()));
// Create a new comparator.
auto comparator = sort->to_apply();
absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements;
std::vector<std::unique_ptr<HloInstruction>> extra_parameters;
std::vector<HloInstruction*> extra_parameter_ptrs;
Shape scalar_shape = ShapeUtil::MakeShape(S32, {});
extra_parameters.push_back(HloInstruction::CreateParameter(
sort->operand_count() * 2, scalar_shape,
absl::StrCat("p.", sort->operand_count(), ".lhs")));
extra_parameter_ptrs.push_back(extra_parameters.back().get());
extra_parameters.push_back(HloInstruction::CreateParameter(
sort->operand_count() * 2 + 1, scalar_shape,
absl::StrCat("p.", sort->operand_count(), ".rhs")));
extra_parameter_ptrs.push_back(extra_parameters.back().get());
sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation(
comparator->CloneWithReplacements(std::move(replacements),
extra_parameter_ptrs)));
// Replace the original sort op.
std::vector<HloInstruction*> new_operands(sort->operands().begin(),
sort->operands().end());
new_operands.push_back(iota);
std::vector<Shape> new_shapes = sort->operand_count() == 1
? std::vector<Shape>{sort->shape()}
: sort->shape().tuple_shapes();
new_shapes.push_back(iota_shape);
Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes);
HloInstruction* new_sort = computation->AddInstruction(
sort->CloneWithNewOperands(new_sort_shape, new_operands));
// Add a "wrapper" around the new sort op to make sure we have the same
// shape as before. For the rank 1 case, we only need a GetTupleElement,
// otherwise we create a Tuple consisting of GetTupleElements of the new
// sort.
std::vector<HloInstruction*> tuple_elements;
tuple_elements.reserve(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
tuple_elements.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
sort->operand(i)->shape(), new_sort, i)));
}
expanded_sort = tuple_elements[0];
if (tuple_elements.size() > 1) {
expanded_sort = computation->AddInstruction(
HloInstruction::CreateTuple(tuple_elements));
}
sort = Cast<HloSortInstruction>(new_sort);
iota_index = sort->operand_count() - 1;
}
// Modify the computation to break ties using the iota operand.
auto comparator = sort->to_apply();
std::vector<HloInstruction*> instructions_postorder =
comparator->MakeInstructionPostOrder();
absl::flat_hash_map<HloInstruction*, HloInstruction*> replacements;
// Look up instr in the replacements map, and return either the replacement,
// or instr, if the replacement isn't present.
auto replace = [&](HloInstruction* instr) {
auto it = replacements.find(instr);
if (it == replacements.end()) {
return instr;
}
return it->second;
};
HloInstruction* old_root = comparator->root_instruction();
// The comparison computation gets 2 * n parameters (n being the number of
// operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two
// different scalars of operand i of Sort which are to be compared. The
// comparison computation should induce a strict weak order, so if
// to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to
// to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the
// values to be compared are equivalent, and perform a tie-breaker comparison.
//
// We clone each instruction with at least one operand, but use as new
// operands of the instruction the replacements of the original operands.
// Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This
// should make sure that the cloned root instruction gives the result of the
// comparison computation when being called with each scalar pair reversed,
// including the parameters corresponding to the iota operand.
for (int64 i = 0; i < comparator->num_parameters(); ++i) {
replacements[comparator->parameter_instruction(i)] =
comparator->parameter_instruction(i ^ 1);
}
HloInstruction* cloned_root = nullptr;
for (HloInstruction* inst : instructions_postorder) {
if (inst->operand_count() == 0) {
continue;
}
std::vector<HloInstruction*> new_operands;
new_operands.reserve(inst->operand_count());
for (HloInstruction* operand : inst->operands()) {
new_operands.push_back(replace(operand));
}
auto new_instruction =
inst->CloneWithNewOperands(inst->shape(), new_operands);
replacements[inst] = new_instruction.get();
if (inst == old_root) {
cloned_root = new_instruction.get();
}
comparator->AddInstruction(std::move(new_instruction));
}
CHECK_NE(cloned_root, nullptr);
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
HloInstruction* same =
comparator->AddInstruction(HloInstruction::CreateBinary(
scalar_pred, HloOpcode::kEq, old_root, cloned_root));
HloInstruction* tie_breaker =
comparator->AddInstruction(HloInstruction::CreateBinary(
scalar_pred, HloOpcode::kLt,
comparator->parameter_instruction(2 * iota_index),
comparator->parameter_instruction(2 * iota_index + 1)));
HloInstruction* new_root =
comparator->AddInstruction(HloInstruction::CreateTernary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
old_root));
comparator->set_root_instruction(new_root);
return expanded_sort;
}
bool StableSortExpander::InstructionMatchesPattern(
HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kSort &&
Cast<HloSortInstruction>(instruction)->is_stable();
}
} // namespace xla
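The comparator rewrite above can be modeled in a few lines of plain standalone C++ (an analogy, not XLA code): if the original comparison treats two keys as equivalent, fall back to comparing their iota indices.
#include <iostream>

bool OriginalCompare(float lhs, float rhs) { return lhs < rhs; }

// Mirrors the rewritten root: Select(Eq(Comp, CompReversed), Lt(iota), Comp).
bool StableCompare(float lhs, float rhs, int lhs_index, int rhs_index) {
  bool forward = OriginalCompare(lhs, rhs);
  bool reversed = OriginalCompare(rhs, lhs);
  bool equivalent = (forward == reversed);  // neither value precedes the other
  return equivalent ? (lhs_index < rhs_index) : forward;
}

int main() {
  // Equal keys: the earlier index wins, which is what stability requires.
  std::cout << StableCompare(1.0f, 1.0f, 3, 5) << "\n";  // prints 1
  // Distinct keys: the original comparison decides.
  std::cout << StableCompare(2.0f, 1.0f, 0, 9) << "\n";  // prints 0
  return 0;
}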

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// HLO pass which expands Sort ops that have the is_stable field set to true
// into equivalent Sort ops which guarantee stable sorting without relying on
// the is_stable field.
class StableSortExpander : public OpExpanderPass {
public:
absl::string_view name() const override { return "stable-sort-expander"; }
private:
bool InstructionMatchesPattern(HloInstruction* instruction) override;
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_
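A minimal hedged sketch of running the pass on a module, matching how the GPU compiler hunk at the top of this diff registers it; `module` is an assumed HloModule*, the hlo_pass_pipeline.h include is assumed, and the enclosing function is assumed to return a Status.
HloPassPipeline pipeline("stable-sort-expansion");
pipeline.AddPass<StableSortExpander>();
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
// `changed` is true iff at least one is_stable Sort was expanded.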

View File

@ -0,0 +1,358 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/stable_sort_expander.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace m = match;
using StableSortExpanderTest = HloTestBase;
// Checks whether 'a' and 'b' are roots of equivalent computations, except that
// parameters 2 * i and 2 * i + 1 are switched.
bool IsSameComputationExceptParams(const HloInstruction* a,
const HloInstruction* b) {
if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) {
return false;
}
if (a->opcode() == HloOpcode::kParameter) {
// Check that parameters were switched.
return a->parameter_number() == (b->parameter_number() ^ 1);
}
// If the operation has no operands, it should actually be the same.
if (a->operand_count() == 0) {
return a == b;
}
// Otherwise recursively compare all operands.
for (int64 i = 0; i < a->operand_count(); ++i) {
if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) {
return false;
}
}
return true;
}
// Check that the comparison computation has been modified to add a tie breaker
// using 'iota_parameter'.
void CheckComputationHasTieBreaker(const HloInstruction* root,
int64 iota_parameter) {
// With the tie breaker, the root instruction should be
// Select(Eq(Comp(), CompReverse()), Lt(), Comp())
// with Comp() being the original comparison function, and CompReverse() being
// the copied comparison function where the parameters are reversed. Lt() is
// the tie breaker comparison using the Iota operand.
ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq);
// Check that the tie breaker instruction is correct.
EXPECT_THAT(root->operand(1),
GmockMatch(m::Lt(m::Parameter(iota_parameter * 2),
m::Parameter(iota_parameter * 2 + 1))));
EXPECT_EQ(root->operand(2), root->operand(0)->operand(0));
// Check that Comp() and CompReverse() are equivalent except that
// CompReverse() has reversed parameters.
EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0),
root->operand(0)->operand(1)));
}
TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = s32[64,8732]{1,0} iota(), iota_dimension=1
sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=true
ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota()), 0)));
CheckComputationHasTieBreaker(
root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
TEST_F(StableSortExpanderTest,
StabilizeSortReuseIotaOperandComplicatedComparison) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
max = u32[] constant(2147483647)
zero = s32[] constant(0)
lhs.signed = s32[] bitcast-convert(p.0.lhs)
lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
lhs.flipped = u32[] subtract(max, lhs.unsigned)
lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero)
lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
rhs.signed = s32[] bitcast-convert(p.0.rhs)
rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
rhs.flipped = u32[] subtract(max, rhs.unsigned)
rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero)
rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
ROOT lt = pred[] less-than(lhs.converted, rhs.converted)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = s32[64,8732]{1,0} iota(), iota_dimension=1
sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=true
ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota()), 0)));
CheckComputationHasTieBreaker(
root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = s32[64,8732]{1,0} parameter(1)
ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=true
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, GmockMatch(m::Tuple(
m::GetTupleElement(
m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0),
m::GetTupleElement(
m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1))));
CheckComputationHasTieBreaker(
root->operand(0)->operand(0)->to_apply()->root_instruction(),
/*iota_parameter=*/2);
}
TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = s32[64,8732]{1,0} iota(), iota_dimension=1
sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=false
ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie());
}
TEST_F(StableSortExpanderTest,
StabilizeSortDontReuseIotaOperandWrongDimension) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = s32[64,8732]{1,0} iota(), iota_dimension=0
sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=true
ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
// Simplify away the "wrapper" tuple around the new sort.
AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
[](const Shape&, const Shape&) { return false; }));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
CheckComputationHasTieBreaker(
root->operand(0)->to_apply()->root_instruction(),
/*iota_parameter=*/2);
}
TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
p.1.lhs = f32[] parameter(2)
p.1.rhs = f32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
}
ENTRY sort_computation {
keys = f32[64,8732]{1,0} parameter(0)
values = f32[64,8732]{1,0} iota(), iota_dimension=1
sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values),
dimensions={1}, to_apply=compare, is_stable=true
ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
// Simplify away the "wrapper" tuple around the new sort.
AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions(
[](const Shape&, const Shape&) { return false; }));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0)));
CheckComputationHasTieBreaker(
root->operand(0)->to_apply()->root_instruction(),
/*iota_parameter=*/2);
}
TEST_F(StableSortExpanderTest, StabilizeSortR1) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = s32[] parameter(0)
p.0.rhs = s32[] parameter(1)
mask = s32[] constant(65535)
lhs = s32[] and(p.0.lhs, mask)
rhs = s32[] and(p.0.rhs, mask)
ROOT lt = pred[] less-than(lhs, rhs)
}
ENTRY sort_computation {
keys = s32[64,8732]{1,0} parameter(0)
ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
is_stable=true
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota()), 0)));
CheckComputationHasTieBreaker(
root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1);
}
TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
const char* hlo_string = R"(
HloModule permutation_sort
compare {
p.0.lhs = s32[] parameter(0)
p.0.rhs = s32[] parameter(1)
mask = s32[] constant(65535)
lhs = s32[] and(p.0.lhs, mask)
rhs = s32[] and(p.0.rhs, mask)
ROOT lt = pred[] less-than(lhs, rhs)
}
ENTRY sort_computation {
keys = s32[64,8732]{1,0} parameter(0)
sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare,
is_stable=true
ROOT neg = s32[64,8732]{1,0} negate(sort)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
StableSortExpander stabilizer;
EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement(
m::Sort(m::Parameter(0), m::Iota()), 0))));
CheckComputationHasTieBreaker(
root->operand(0)->operand(0)->to_apply()->root_instruction(),
/*iota_parameter=*/1);
}
} // namespace
} // namespace xla

View File

@ -1072,7 +1072,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto keys = builder.AddInstruction( auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys")); HloInstruction::CreateParameter(0, keys_shape, "keys"));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, MakeSortHlo(keys_shape, {keys}, 0, &builder, module_.get())); auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false,
&builder, module_.get()));
computation_ = module_->AddEntryComputation(builder.Build()); computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis(); RunAnalysis();
@ -1094,7 +1095,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto* sort, auto* sort,
MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
{keys, values}, 0, &builder, module_.get())); {keys, values}, 0, /*is_stable=*/false, &builder,
module_.get()));
computation_ = module_->AddEntryComputation(builder.Build()); computation_ = module_->AddEntryComputation(builder.Build());
RunAnalysis(); RunAnalysis();

View File

@ -1146,7 +1146,7 @@ xla_test(
xla_test( xla_test(
name = "reduce_test", name = "reduce_test",
srcs = ["reduce_test.cc"], srcs = ["reduce_test.cc"],
shard_count = 40, shard_count = 31,
tags = [ tags = [
"optonly", "optonly",
], ],

View File

@ -1188,6 +1188,8 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
p{v{6}, v{6, 7}, "b,bc->c"}, p{v{6}, v{6, 7}, "b,bc->c"},
p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
p{v{77}, v{77}, "a,a->a"}, p{v{77}, v{77}, "a,a->a"},
p{v{77}, v{77, 55}, "a,ab->ba"}, p{v{77}, v{77, 55}, "a,ab->ba"},
p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
@ -1265,5 +1267,51 @@ ENTRY %test {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
} }
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
// Tests for a caching bug in the XLA CPU backend.
absl::string_view hlo_string =
R"(
HloModule CpuTiledDotEmitterCachingBug
ENTRY main {
lhs = f32[20,40] parameter(0)
rhs_0 = f32[40,1] parameter(2)
rhs_1 = f32[1,40] parameter(1)
dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
ROOT result = f32[20,1] divide(dot_0, dot_1)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
// Tests for a caching bug in the XLA CPU backend.
absl::string_view hlo_string =
R"(
HloModule CpuTiledDotEmitterCachingBug
ENTRY main {
lhs_0 = f32[20,40] parameter(0)
rhs_0 = f32[40,1] parameter(1)
lhs_1 = f32[1,40] parameter(2)
rhs_1 = f32[20,40] parameter(3)
dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
dot_0_reshaped = f32[20] reshape(dot_0)
dot_1_reshaped = f32[20] reshape(dot_1)
ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -205,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
} }
StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
int64 num_replicas) {
HloRunner::ReplicatedExecuteOptions options;
options.num_replicas = num_replicas;
for (auto argument : arguments) {
options.arguments.push_back(argument);
}
return test_runner_.ExecuteReplicated(std::move(module), options);
}
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule( StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
const HloModule& test_module, const HloModule& test_module,
const std::function<void(HloModule*)>& reference_preprocessor) { const std::function<void(HloModule*)>& reference_preprocessor) {

View File

@ -173,6 +173,11 @@ class HloTestBase : public ::testing::Test {
Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module, Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
absl::Span<Literal* const> arguments); absl::Span<Literal* const> arguments);
// Executes the given module on multiple replicas.
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
int64 num_replicas);
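A hypothetical test-body sketch using the new helper; `hlo_string` is an assumed replicated-friendly module text and the expectations are illustrative only.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string));
Literal input = LiteralUtil::CreateR0<float>(42.0f);
TF_ASSERT_OK_AND_ASSIGN(
    std::vector<Literal> results,
    ExecuteReplicated(std::move(module), {&input}, /*num_replicas=*/2));
EXPECT_EQ(results.size(), 2);  // one result literal per replica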
// Executes the given hlo module on two backends and compares results. // Executes the given hlo module on two backends and compares results.
// //
// 'arguments': the input of the hlo module. // 'arguments': the input of the hlo module.

View File

@ -30,6 +30,11 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly =
**kwargs **kwargs
) )
def xla_py_proto_library(**kwargs):
# Note: we don't currently define a proto library target for Python in OSS.
_ignore = kwargs
pass
def xla_py_grpc_library(**kwargs): def xla_py_grpc_library(**kwargs):
# Note: we don't currently define any special targets for Python GRPC in OSS. # Note: we don't currently define any special targets for Python GRPC in OSS.
_ignore = kwargs _ignore = kwargs

View File

@ -122,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
.HostMemory("literal"), .HostMemory("literal"),
XRTReadLiteralOp<true, XRTGenericDeviceAccessor>); XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
.Device(DEVICE_XLA_GPU)
.HostMemory("handles")
.HostMemory("tensors"),
XRTReadToTensorOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
.Device(DEVICE_XLA_CPU)
.HostMemory("handles")
.HostMemory("tensors"),
XRTReadToTensorOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
.Device(DEVICE_XLA_GPU) .Device(DEVICE_XLA_GPU)
.HostMemory("handle"), .HostMemory("handle"),

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"
@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
@ -215,27 +217,29 @@ class XRTAllocateFromTensorOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
std::vector<int64> minor_to_major;
if (ctx->HasAttr("layouts")) { if (ctx->HasAttr("layouts")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
} }
OP_REQUIRES( OP_REQUIRES(
ctx, tf_shapes_.size() == dtypes_.size(), ctx, tf_shapes_.size() == dtypes_.size(),
errors::InvalidArgument("shapes and dtypes must be the same length")); errors::InvalidArgument("shapes and dtypes must be the same length"));
std::vector<xla::Shape> xla_shapes; std::vector<xla::Shape> xla_shapes;
xla_shapes.reserve(tf_shapes_.size());
for (int i = 0; i < tf_shapes_.size(); i++) { for (int i = 0; i < tf_shapes_.size(); i++) {
xla::Shape xla_shape; xla::Shape xla_shape;
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
xla_shapes.push_back(xla_shape); xla_shapes.push_back(std::move(xla_shape));
} }
if (xla_shapes.size() > 1 || make_tuple) { if (xla_shapes.size() > 1 || make_tuple) {
shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
} else { } else {
shape_.Swap(&xla_shapes.front()); shape_.Swap(&xla_shapes.front());
} }
if (!minor_to_major_.empty()) { if (!minor_to_major.empty()) {
xla::Shape shape_with_layouts; xla::Shape shape_with_layouts;
OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
/*layout_func=*/nullptr, /*layout_func=*/nullptr,
&shape_with_layouts)); &shape_with_layouts));
shape_.Swap(&shape_with_layouts); shape_.Swap(&shape_with_layouts);
@ -304,7 +308,6 @@ class XRTAllocateFromTensorOp : public OpKernel {
private: private:
std::vector<TensorShape> tf_shapes_; std::vector<TensorShape> tf_shapes_;
DataTypeVector dtypes_; DataTypeVector dtypes_;
std::vector<int64> minor_to_major_;
xla::Shape shape_; xla::Shape shape_;
}; };
@ -487,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel {
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
ctx, allocation->device_ordinal(), &device_ref)); ctx, allocation->device_ordinal(), &device_ref));
xla::Literal literal; xla::Literal literal(allocation->on_host_shape());
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, allocation->ToLiteral(device_ref.backend(), ctx, allocation->ToLiteral(device_ref.backend(),
device_ref.device_ordinal(), &literal)); device_ref.device_ordinal(), &literal));
@ -499,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel {
} }
}; };
// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
public:
explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
}
~XRTReadToTensorOp() override = default;
XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "XRTReadToTensorOp::Compute";
const Tensor& handle_tensor = ctx->input(0);
// TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
// just scalars).
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
errors::Internal("computation input should be an int64 scalar"));
int64 allocation_handle = handle_tensor.scalar<int64>()();
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
core::ScopedUnref allocation_unref(allocation);
if (discard_) {
VLOG(2) << "Releasing handle " << allocation_handle;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
rm, allocation_handle));
}
// We are guaranteed that the underlying device object won't be deleted out
// from under us while the ScopedRef is live.
class DeviceAccessor::ScopedRef device_ref;
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
ctx, allocation->device_ordinal(), &device_ref));
xla::Shape shape = allocation->on_host_shape();
int output = 0;
Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
&shape,
[&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
if (subshape->IsTuple()) return Status::OK();
xla::PrimitiveType xla_type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
ctx->expected_output_dtype(output), &xla_type));
if (xla_type != subshape->element_type()) {
return errors::InvalidArgument(
"Type mismatch between buffer type (", subshape->ToString(),
") and tensor type (",
DataTypeString(ctx->expected_output_dtype(output)),
") for output tensor ", output);
}
TensorShape output_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));
Tensor* output_tensor;
TF_RETURN_IF_ERROR(
ctx->allocate_output(output, output_shape, &output_tensor));
XRTTupleAllocation* sub;
TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
allocation, index, &sub, /*alias_parent_allocation=*/true));
core::ScopedUnref sub_unref(sub);
xla::MutableBorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
&literal));
TF_RETURN_IF_ERROR(sub->ToLiteral(
device_ref.backend(), device_ref.device_ordinal(), &literal));
++output;
return Status::OK();
});
OP_REQUIRES_OK(ctx, status);
}
bool discard_;
DataTypeVector dtypes_;
};
// Op that writes a new literal value into device-resident memory. // Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor> template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel { class XRTWriteLiteralOp : public OpKernel {

View File

@ -151,6 +151,27 @@ releases the handle.
'literal' is a serialized xla::LiteralProto proto. 'literal' is a serialized xla::LiteralProto proto.
)"); )");
REGISTER_OP("XRTReadToTensor")
.Input("handles: int64")
.Attr("release_handles: bool = False")
.Attr("dtypes: list(type)")
.Output("tensors: dtypes")
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
.Doc(
R"(
Copies allocated values from device memory and returns them as zero or more
Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned.
In general, the tensors returned for a handle correspond to an in-order traversal
of the tuple-tree value referenced by the handle.
'handles' contains ids returned from Ops that produced on-device allocations.
At present, only a single (scalar) handle is supported.
'dtypes' are the expected types for each `Tensor` to be returned. If the
expected and actual tensor types do not match, an error is returned.
'release_handles': if True, `handles` are released.
'tensors' are the output Tensors.
)");
REGISTER_OP("XRTReleaseAllocationHandle") REGISTER_OP("XRTReleaseAllocationHandle")
.Input("handle: int64") .Input("handle: int64")
.SetShapeFn(tensorflow::shape_inference::NoOutputs) .SetShapeFn(tensorflow::shape_inference::NoOutputs)

View File

@ -220,7 +220,7 @@ XRTTupleAllocation::~XRTTupleAllocation() {
} }
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
xla::Literal* literal) { xla::MutableLiteralBase* literal) {
auto transfer_manager = backend->transfer_manager(); auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
@ -234,9 +234,8 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
" has been released"); " has been released");
} }
} }
TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( return transfer_manager->TransferLiteralFromDevice(stream.get(),
stream.get(), shaped_buffer)); shaped_buffer, *literal);
return Status::OK();
} }
Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,

View File

@ -147,7 +147,7 @@ class XRTTupleAllocation : public ResourceBase {
// Copies the allocation from device to host and returns it in literal. // Copies the allocation from device to host and returns it in literal.
Status ToLiteral(xla::Backend* backend, int device_ordinal, Status ToLiteral(xla::Backend* backend, int device_ordinal,
xla::Literal* literal); xla::MutableLiteralBase* literal);
// Write a new literal value to the allocation. // Write a new literal value to the allocation.
Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

View File

@ -218,7 +218,6 @@ cc_library(
"//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
] + select({ ] + select({
"//tensorflow:android": [], "//tensorflow:android": [],
"//tensorflow:ios": [], "//tensorflow:ios": [],

Some files were not shown because too many files have changed in this diff.