diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index a6eb4755f32..ddcacfcbe2d 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -26,14 +26,28 @@ import sys as _sys # API IMPORTS PLACEHOLDER +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +# We're using bitwise, but there's nothing special about that. +_API_MODULE = bitwise # pylint: disable=undefined-variable +_current_module = _sys.modules[__name__] +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg="Limited tf.summary API due to missing TensorBoard installation") _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api._v2.estimator')) -_current_module = _sys.modules[__name__] if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, @@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v2.keras')) -# Make sure directory containing top level submodules is in -# the __path__ so that "from tensorflow.foo import bar" works. -# We're using bitwise, but there's nothing special about that. -_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if not hasattr(_current_module, '__path__'): - __path__ = [_tf_api_dir] -elif _tf_api_dir not in __path__: - __path__.append(_tf_api_dir) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index eeca8f0d566..5eb25a81b7f 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -70,7 +70,7 @@ _API_MODULE = app # pylint: disable=undefined-variable # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. 
-_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a09becc49b1..4c4d587fce0 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -150,6 +150,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ], ) @@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) +tf_gen_op_wrappers_cc( + name = "tpu_ops", + include_internal_ops = 1, + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + cc_library_with_android_deps( name = "cc_op_gen_main", srcs = [ diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 52345a376cc..dedd55f16af 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -81,6 +81,7 @@ cc_library( ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ]) + if_android([ diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 05fd9cd981f..2cf68c9cd83 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -22,11 +22,16 @@ import os as _os import sys as _sys # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # API IMPORTS PLACEHOLDER from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg=( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation")) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1f8ec09e19c..261519de347 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); -REGISTER_OP("_XlaSendFromHost") - .Input("inputs: Tinputs") - .Input("dynamic_key: string") - .Attr("Tinputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - -REGISTER_OP("_XlaRecvAtHost") - .Input("dynamic_key: string") - .Output("outputs: Toutputs") - .Attr("Toutputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc 
b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 109684be72a..f0c9d573451 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -200,7 +200,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, auto serialized = absl::make_unique(size); TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size)); uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return Status::OK(); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 9b6ca4092c3..7c1e0daf0b7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -250,6 +250,29 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "self_adjoint_eig_op_test", + size = "medium", + srcs = ["self_adjoint_eig_op_test.py"], + # TODO(kuny): remove it after b/124377352 is fixed. + disabled_backends = [ + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py new file mode 100644 index 00000000000..cfb5c82b22e --- /dev/null +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.self_adjoint_eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _test(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + x_np = x_np + np.swapaxes(x_np, -1, -2) + n = shape[-1] + + e_np, _ = np.linalg.eigh(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + e, v = linalg_ops.self_adjoint_eig(x_tf) + e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np}) + + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6) + self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6) + + SIZES = [1, 2, 5, 10, 32] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSelfAdjointEig(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._test(dtype, batch_dims + (n, n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 47e0f384a4f..a380715301b 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -102,7 +102,7 @@ class ListOpsTest(xla_test.XLATestCase): _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Set the max number of elements"): - self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) def testEmptyTensorListMax(self): with self.cached_session() as sess, self.test_scope(): @@ -136,6 +136,17 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 2.0]) + def testSetDoesNotUpdatePushIndex(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) + # SetItem should not change the push index. + l = list_ops.tensor_list_set_item(l, 1, 3.) + l = list_ops.tensor_list_push_back(l, 5.) + l = list_ops.tensor_list_push_back(l, 7.) 
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [5., 7.]) + def testGetSetReserved(self): with self.cached_session(), self.test_scope(): l = list_ops.tensor_list_reserve( @@ -146,6 +157,25 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 0.0]) + def testSetStackReservedUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) + + def testPushInEmptyListWithUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) + l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) + # Pushing an element with a different shape should raise an error. + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): + l = list_ops.tensor_list_push_back(l, 5.) + self.evaluate( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) + def testGetSetReservedNonScalar(self): with self.cached_session() as sess, self.test_scope(): l = list_ops.tensor_list_reserve( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 831e203f49d..f2e0eac2d99 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -72,6 +72,7 @@ class UnaryOpsTest(xla_test.XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: @@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], + dtype=dtype)).astype(dtype), rtol=1e-4, atol=1e-6) @@ -710,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=dtype)) + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) self._assertOpOutputMatchesExpected( math_ops.negative, @@ -880,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1], [1], [4]], dtype=dtype), expected=np.int32(3)) + def testSizeWithInt64OutType(self): + + def size_op(x): + return array_ops.size_internal(x, optimize=False, out_type=np.int64) + + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + size_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int64(3)) + def testUnpack(self): self._assertOpOutputMatchesExpected( array_ops.unstack, @@ -989,7 +1002,7 @@ class UnaryOpsTest(xla_test.XLATestCase): def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) - expected = np.logaddexp(zero, features) + expected = np.logaddexp(zero, features).astype(dtype) self._assertOpOutputMatchesExpected( nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD 
b/tensorflow/compiler/tf2tensorrt/BUILD index eca533095ac..63cad6a159c 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -171,13 +171,11 @@ tf_cuda_library( name = "trt_resources", srcs = [ "utils/trt_int8_calibrator.cc", - "utils/trt_resource_manager.cc", "utils/trt_resources.cc", ], hdrs = [ "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", - "utils/trt_resource_manager.h", "utils/trt_resources.h", ], deps = [ @@ -266,7 +264,6 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -362,11 +359,12 @@ cc_library( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "segment_test", size = "small", srcs = ["segment/segment_test.cc"], tags = [ + "no_cuda_on_cpu_tap", "no_windows", "nomac", ], @@ -432,7 +430,7 @@ cc_library( copts = tf_copts(), deps = [ "//tensorflow/core:framework", - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", ], ) @@ -441,7 +439,7 @@ cc_library( srcs = ["utils/test_utils.cc"], hdrs = ["utils/test_utils.h"], deps = [ - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", "@com_googlesource_code_re2//:re2", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index d6080c02d43..0b0cb0db8ef 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" @@ -106,6 +105,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "ExpandDims", "FusedBatchNorm", "FusedBatchNormV2", + "GatherV2", "Identity", "LeakyRelu", "Log", @@ -190,55 +190,6 @@ tensorflow::Status BuildNodeMap( } // namespace -// Function to get calibration from ResourceMgr and put them into nodedef. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, - bool is_dyn_op) { - LOG(INFO) << "Starting Calib Conversion"; - *infer_graph = graph_def; - auto trt_rm = TRTResourceManager::instance(); - auto calib_rm = trt_rm->getManager("TRTCalibration"); - int num_nodes = infer_graph->node_size(); - if (!is_dyn_op) { - LOG(WARNING) << "Construction of static int8 engine is not implemented " - "yet!. Dynamic engine will be constructed"; - } - for (int i = 0; i < num_nodes; ++i) { - auto n = infer_graph->mutable_node(i); - if (n->op() == "TRTEngineOp") { - VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); - TRTCalibrationResource* cres = nullptr; - auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); - if (!status.ok()) { - LOG(ERROR) << "Could not get Calibration information. 
Did you run with " - "calibration data?"; - return tensorflow::errors::FailedPrecondition( - "Need to run graph with calibration data first!"); - } - tensorflow::core::ScopedUnref calib_sc(cres); - if (cres->calibrator_) { - cres->calibrator_->waitAndSetDone(); - cres->thr_->join(); - const auto& calibration_table = - cres->calibrator_->getCalibrationTableAsString(); - if (calibration_table.empty()) { - LOG(ERROR) << "Calibration table is empty"; - return tensorflow::errors::Unknown( - "Calibration table is missing. This shouldn't have happened!"); - } - n->mutable_attr()->at("calibration_data").set_s(calibration_table); - } else { - LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; - return tensorflow::errors::Unknown( - "Can't get TRTCalibrator from resource manager!"); - } - TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); - } - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 95cf0227dcf..80f68d36a3a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -85,12 +85,6 @@ struct ConversionParams { std::vector cached_engine_batches; // list of cached engines }; -// This method extracts calibration information from the resource managers -// and puts them in to engine nodedefs. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, - bool is_dyn_op); - // - max_batch_size: maximum batch size which can be used for inference for // optimization targets inference run with max batch size. // - max_workspace_size_bytes: The upper bound of memory allowance for engine diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 0d5b9851f79..002526c04bb 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" @@ -379,6 +378,32 @@ tensorflow::Status CreateBroadcastableScalarConstant( return Status::OK(); } +// Convert an axis from TF format to TRT format while validating. TF format +// includes the batch dimension, while TRT does not. TF can also use negative +// indices. +// TODO(tmorris): Use this method in more ops. +tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims, + absl::string_view node_name, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + 1; + // Check bounds. + if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { + return tensorflow::errors::InvalidArgument( + "Axis value of ", tf_axis, " is out of bounds, must be in range [", + -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name); + } + // Make negative axis positive. 
+ if (tf_axis < 0) tf_axis += tf_nb_dims; + // Don't allow axis to be the batch dimension. + if (tf_axis == 0) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow manipulation of the batch dimension, at ", + node_name); + } + // Remove batch dimension. + *trt_axis = tf_axis - 1; + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -3413,6 +3438,29 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertGather(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"params", false}, {"indices", false}, {"axis", true}})); + absl::Span axis = inputs.at(2).weights().GetSpan(); + if (axis.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Axis for GatherV2 must be a scalar, at ", node_def.name()); + } + int trt_axis = 0; + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, + node_def.name(), &trt_axis)); + if (params->validation_only) return Status::OK(); + + nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( + *const_cast(inputs.at(0).tensor()), + *const_cast(inputs.at(1).tensor()), trt_axis); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights tensor_input, TRT_ShapedWeights weights_raw, @@ -3643,6 +3691,7 @@ static void RegisterValidatableOpConverters( (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["GatherV2"] = ConvertGather; (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pad"] = ConvertPad; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d1e30eb848b..cbba01ba576 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -190,6 +190,11 @@ class TRT_ShapedWeights { string DebugString() const; + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), count()); + } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. tensorflow::DataType type_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index bb1341ada37..bb6fc7f0e48 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3129,6 +3129,126 @@ TEST_F(OpConverterTest, ConvertTopK) { } } +template +void TestConvertGather(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + // Get the NodeDef for GatherV2. 
+ Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), dtype); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + + struct TestParams { + std::vector params_dims; + std::vector indices_dims; + std::vector indices; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Input is the same {1, 2, 3, 4, 5, 6} for all cases. + const int kGatherOKCases = 5; + TestParams ok_params[kGatherOKCases] = { + // Vector indices (output is rank(params)). + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}}, + TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}}, + // Higher rank indices (output is rank(params) + rank(indices) - 1). + TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}}, + }; + + // Ok. + for (int i = 0; i < kGatherOKCases; i++) { + test->Reset(); + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + // Create input in CType and convert expected output to CType. + std::vector inputs = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; + std::vector converted_expected_output( + ok_params[i].expected_output.begin(), + ok_params[i].expected_output.end()); + + const DataVec input_data{ + {"params", test::AsTensor(inputs)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec output_data{ + {"my_gather", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(converted_expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertGather) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "GatherV2 got 0 inputs but expected 3, at my_gather"); + } + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for GatherV2 must be a constant, at my_gather"); + } + { + // Axis is out of bounds, should fail. 
+ Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_gather"); + } + { + // Axis is batch dimension, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_gather"); + } + + Reset(); + TestConvertGather(this); + TestConvertGather(this); + TestConvertGather(this); +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index e3b31d736eb..f6d387c59cd 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -295,27 +294,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return this->AllocateCalibrationResources(ctx, cr); }})); tensorflow::core::ScopedUnref calib_sc(calib_res); - // TODO(aaroey): here we also add the resource to the ResourceMgr singleton. - // This is needed before we migrate all uses of calib_graph_to_infer_graph() - // to the new calibration workflow. After that we'll remove this block. - { - auto deprecated_rm = - TRTResourceManager::instance()->getManager("TRTCalibration"); - TRTCalibrationResource* copied_resource = nullptr; - // Check whether the resource exists, and create it if not. - if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource) - .ok()) { - // Do nothing if the resource exists. - copied_resource->Unref(); - } else { - copied_resource = calib_res; - // Increase the refcount by 1 then transfer the ownership of that refcount - // to the ResourceMgr singleton. - copied_resource->Ref(); - OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator", - copied_resource)); - } - } int num_inputs = ctx->num_inputs(); // Pass input data to calibrator std::unordered_map input_data; diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 7503d4d984e..25fb3a13db9 100644 --- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -38,16 +38,23 @@ def load_trt_ops(): if _trt_ops_so: return + try: + # pylint: disable=g-import-not-at-top,unused-variable + # This registers the TRT ops, it doesn't require loading TRT library. + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + # pylint: enable=g-import-not-at-top,unused-variable + except ImportError as e: + print("**** Failed to import TF-TRT ops. This is because the binary was " + "not built with CUDA or TensorRT enabled. 
****") + raise e + # TODO(laigd): we should load TF-TRT kernels here as well after removing the # swig binding. try: - # TODO(lagid): It is not known why these unused imports were introduced. - # Investigate and get rid of these, if not required. - # pylint: disable=unused-import,g-import-not-at-top,unused-variable - from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + # pylint: disable=g-import-not-at-top from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader - # pylint: enable=unused-import,g-import-not-at-top,unused-variable + # pylint: enable=g-import-not-at-top _trt_ops_so = load_library.load_op_library( resource_loader.get_path_to_datafile("_trt_ops.so")) diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 3bcca99afbf..dd3c09d7e42 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include "re2/re2.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h index bcd628b62f0..d85875991b7 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -16,8 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc deleted file mode 100644 index 0a72a88bc74..00000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace tensorrt { - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::instance() { - static std::shared_ptr instance_(new TRTResourceManager); - return instance_; -} - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { - // mutex is held for lookup only. Most instantiations where mutex will be held - // longer will be during op creation and should be ok. 
- tensorflow::mutex_lock lock(map_mutex_); - auto s = managers_.find(op_name); - if (s == managers_.end()) { - auto it = managers_.emplace( - op_name, std::make_shared(op_name)); - VLOG(1) << "Returning a new manager " << op_name; - return it.first->second; - } - VLOG(1) << "Returning old manager " << op_name; - return s->second; -} - -} // namespace tensorrt -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h deleted file mode 100644 index 03879ffff2f..00000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#include - -#include -#include -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTResourceManager { - TRTResourceManager() = default; - - public: - static std::shared_ptr instance(); - // returns a manager for given op, if it doesn't exists it creates one - std::shared_ptr getManager(const string& op_name); - - private: - std::unordered_map> - managers_; - tensorflow::mutex map_mutex_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a1a9435c19..7d9e7b9fc1f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -24,7 +24,7 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -60,6 +60,14 @@ xla_proto_library( ], ) +xla_py_proto_library( + name = "tf2xla_py", + has_services = False, + api_version = 2, + visibility = ["//visibility:public"], + deps = [":tf2xla_proto"], +) + xla_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], @@ -283,6 +291,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b3f050c52b3..343568b2392 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -107,11 +107,13 @@ tf_kernel_library( "xla_pad_op.cc", "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", + "xla_self_adjoint_eig_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], + tags = ["optonly"], deps = [ 
":conv_op_helpers", ":if_op", @@ -143,6 +145,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:quantize", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:bitwise_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 31d4cc13160..280b68383c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -104,7 +104,7 @@ class SizeOp : public XlaOpKernel { for (int64 i = 0; i < rank; ++i) { size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); } - size = xla::ConvertElementType(size, xla::S32); + size = xla::ConvertElementType(size, ctx->output_xla_type(0)); ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 65020012283..8958a48bc79 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -69,6 +71,43 @@ class TensorListLengthOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); +// Creates an empty list with size (leading_dim, *element_shape) if +// element_shape is known at compile time. Otherwise creates one with size +// (leading_dim, 0) which gets initialized later in `GetInitializedList`. +Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, + int64 leading_dim, DataType dtype, xla::XlaOp* list) { + TensorShape list_shape; + list_shape.AddDim(leading_dim); + xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); + TF_ASSIGN_OR_RETURN( + bool is_element_shape_compile_time_const, + element_shape_handle.builder()->IsConstant(element_shape_handle)); + PartialTensorShape partial_element_shape; + if (is_element_shape_compile_time_const) { + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( + element_shape_index, &partial_element_shape)); + } + if (is_element_shape_compile_time_const && + partial_element_shape.IsFullyDefined()) { + TensorShape element_shape; + partial_element_shape.AsTensorShape(&element_shape); + list_shape.AppendShape(element_shape); + } else { + // If element_shape is not a compile time constant or if it is not fully + // defined we will have to wait for the first write call to fully allocate + // the array. + // TODO(srbs): We are using element_shape of [0] as a proxy to denote an + // uninitialized list. A better implementation may be to represent the + // list as a 3-tuple containining an explicit "initialized" flag. 
However, + // we would still need to create a dummy tensor for the first tuple + // element. + list_shape.AddDim(0); + } + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + return Status::OK(); +} + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -76,20 +115,15 @@ class TensorListReserveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - TensorShape tensor_shape; - tensor_shape.AddDim(num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, num_elements)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, num_elements)})); } private: @@ -110,8 +144,6 @@ class EmptyTensorListOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); OP_REQUIRES( @@ -119,15 +151,13 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. Set the max number of elements.")); - TensorShape tensor_shape; - tensor_shape.AddDim(max_num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, + CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, 0)})); } private: @@ -274,6 +304,36 @@ REGISTER_XLA_OP( Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), TensorListFromTensorOp); +// Returns the 0'th element of `tuple` containing the list tensor if it has been +// initialized already else creates one lazily. This allows lazy initialization +// of the list on the first call to SetItem or PushBack. +Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, + const TensorShape& element_shape, DataType dtype, + xla::XlaOp* list) { + *list = xla::GetTupleElement(tuple, 0); + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); + int64 leading_dim = list_shape.dim_size(0); + TensorShape list_element_shape = list_shape; + list_element_shape.RemoveDim(0); + // This checks for the lazy initialization contract set by CreateEmptyList. + // In TensorListReserve if the element_shape is not known at compile time, + // it creates a list with shape [leading_dim, 0]. + if (element_shape != list_element_shape) { + if (list_element_shape.num_elements() != 0) { + return errors::InvalidArgument( + "Invalid shape of value in TensorListSetItem. 
Expected: ", + list_element_shape.DebugString(), + " Actual: ", element_shape.DebugString()); + } + list_shape = element_shape; + list_shape.InsertDim(0, leading_dim); + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + } + return Status::OK(); +} + class TensorListSetItemOp : public XlaOpKernel { public: explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -285,7 +345,9 @@ class TensorListSetItemOp : public XlaOpKernel { xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(2); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); + xla::XlaOp index = ctx->Input(1); xla::XlaOp value = ctx->Input(2); @@ -299,8 +361,8 @@ class TensorListSetItemOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), - index + xla::ConstantR0(b, 1)})); + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), + xla::GetTupleElement(tl, 1)})); } private: @@ -319,11 +381,14 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp tl = ctx->Input(0); + xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); - xla::XlaOp index = xla::GetTupleElement(tl, 1); + xla::XlaOp list; + OP_REQUIRES_OK( + ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); + + xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -336,7 +401,7 @@ class TensorListPushBackOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), index + xla::ConstantR0(b, 1)})); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc new file mode 100644 index 00000000000..233ac8e7b45 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { +namespace { + +class XlaSelfAdjointEigOp : public XlaOpKernel { + public: + explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = + xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } + + private: + bool lower_; + int32 max_iter_; + float epsilon_; +}; + +class SelfAdjointEigV2Op : public XlaOpKernel { + public: + explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on the heuristic that approximately log(n) sweep updates are needed. + // Note: the heuristic provides no theoretical guarantee; max_iter=100 and + // epsilon should be used to determine the exit condition. + int max_iter = 2 * tensorflow::Log2Ceiling(n); + auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } +}; + +REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), + XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), + SelfAdjointEigV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index af641131ed7..ccd58071d35 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor rhs_output: the broadcasted RHS tensor )doc"); +REGISTER_OP("XlaSelfAdjointEig") + .Input("a: T") + .Attr("lower: bool") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Output("w: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + +a: the input tensor. + +lower: a boolean that specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + +max_iter: maximum number of sweep updates, i.e., over the whole lower triangular + part or the upper triangular part based on parameter lower. Heuristically, it has + been argued that approximately log(N) sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +w: The eigenvalues in ascending order, each repeated according to its + multiplicity. +v: The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. 
+)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 345193c936a..de4710d03a3 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) +def self_adjoint_eig(a, lower, max_iter, epsilon): + return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 08332645237..3221ec5b727 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -185,9 +185,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of which retvals have layout to update. The first element is - // the output index, second element is the new layout. - std::vector> retval_to_update_layout; + // Keeps track of the layout of each retval. If a retval is not in this list, + // a descending layout is used. The first element is the output index, second + // element is the new layout. + std::vector> retval_index_and_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -216,7 +217,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + retval_index_and_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); @@ -289,6 +290,11 @@ Status BuildComputation( // Ensures the correct sharding is applied to the output. handle = identity_op(handle); + // Set layout of the retval to device representation layout. + if (resource->representation_shape().has_value()) { + retval_index_and_layout.emplace_back( + elems.size(), resource->representation_shape()->layout()); + } elems.push_back(handle); } } @@ -318,15 +324,15 @@ Status BuildComputation( computation->GetProgramShape()); *output_shape = program_shape.result(); // Update the output layout to the layout of retval. 
- for (auto& update : retval_to_update_layout) { + for (auto& index_and_layout : retval_index_and_layout) { if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = update.second; + *output_shape->mutable_layout() = index_and_layout.second; continue; } - xla::Shape* output_sub_shape = - xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); - *output_sub_shape->mutable_layout() = update.second; + xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( + output_shape, {index_and_layout.first}); + *output_sub_shape->mutable_layout() = index_and_layout.second; } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 492010f7317..b31137867d7 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } +// Tests that the compiler can correctly propagate the layout assigned by +// shape_representation_fn_ to return types. +TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + options.shape_representation_fn = + [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + // Check that the return shapes are correctly transposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + +// The layout of a resource variable shouldn't change after a transpose. +TEST_F(XlaCompilerTest, TransposeVariables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. 
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto transposed_read = ops::Transpose(scope, read, {1, 0}); + auto reshape = ops::Reshape(scope, transposed_read, {2, 3}); + auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + // Tests that the compiler doesn't reorder the parameters. TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 78bc2c94425..ee11f3a3de6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -319,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } +Status XlaOpKernelContext::ConstantInputAsPartialShape( + int index, PartialTensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + // If `literal` is a scalar it's value must be -1. + if (literal.shape().rank() == 0) { + int64 shape_val; + TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); + if (shape_val != -1) { + return errors::InvalidArgument( + "Cannot convert value to PartialTensorShape: ", shape_val); + } + *shape = PartialTensorShape(); // Shape with unknown rank. 
+ return Status::OK(); + } + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = PartialTensorShape(dims); + return Status::OK(); +} + Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { @@ -447,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index, } } +xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { SetOutputExpression( index, @@ -503,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, handle = xla::Reshape(handle, xla::AsInt64Slice(representation_shape.dimensions())); } + variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index e44415f60bf..cc2d5e8de3e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -138,6 +138,10 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. @@ -155,6 +159,11 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } + // Returns the type of output `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType output_xla_type(int index); + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 736588bb8b8..ab3a5bdd9bc 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -86,6 +86,12 @@ class XlaResource { // variables have new values that need to be written back. const xla::XlaOp& initial_value() const { return initial_value_; } + // An xla shape that indicates how this resource variable is represented on + // device. + const absl::optional& representation_shape() const { + return representation_shape_; + } + // A variable is initialized if it has a value. bool initialized() const { return value_.valid(); } @@ -100,6 +106,11 @@ class XlaResource { // Sets the current value of the resource to an all-zero value. Status SetZeroValue(xla::XlaBuilder* builder); + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. 
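Illustration (not part of the patch): the ConstantInputAsPartialShape helper added above accepts either a scalar, which must be -1 and denotes a shape of unknown rank, or a 1-D integer tensor whose entries become the dimensions, with TensorFlow's usual convention that a -1 entry marks an unknown dimension of the resulting PartialTensorShape. A minimal Python sketch of that convention; parse_partial_shape is a hypothetical stand-in, not a TensorFlow API, and None stands in for "unknown":

def parse_partial_shape(value):
    # A scalar constant must be -1 and means "unknown rank".
    if not isinstance(value, (list, tuple)):
        if value != -1:
            raise ValueError(
                "Cannot convert value to PartialTensorShape: %d" % value)
        return None
    # A 1-D constant lists the dimensions; -1 marks an unknown dimension.
    return [None if d == -1 else d for d in value]

print(parse_partial_shape(-1))          # None, i.e. unknown rank
print(parse_partial_shape([2, -1, 3]))  # [2, None, 3]
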
A // TensorArray can have multiple named gradients; see the operator @@ -160,6 +171,10 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; + // An xla shape that indicates how this resource variable is represented on + // device. + absl::optional representation_shape_; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 0abf546c14b..c5dea5f1803 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -452,11 +452,12 @@ cc_library( ) cc_library( - name = "self_adjoint_eigen", - srcs = ["self_adjoint_eigen.cc"], - hdrs = ["self_adjoint_eigen.h"], + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], deps = [ ":arithmetic", + ":comparators", ":constants", ":loops", ":math", @@ -473,9 +474,12 @@ cc_library( ) xla_test( - name = "self_adjoint_eigen_test", - size = "medium", - srcs = ["self_adjoint_eigen_test.cc"], + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], real_hardware_only = True, shard_count = 10, tags = ["optonly"], @@ -483,7 +487,7 @@ xla_test( ":arithmetic", ":constants", ":matrix", - ":self_adjoint_eigen", + ":self_adjoint_eig", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc similarity index 62% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index c2c8caee6ea..546127e4627 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" #include #include #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" @@ -42,7 +43,6 @@ namespace { struct SymmetricSchurDecomposition { XlaOp c; // cosine. XlaOp s; // sine. - XlaOp reduction; // Reduction in the off diagonal after applying G. 
}; // JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix @@ -51,7 +51,11 @@ struct SymmetricSchurDecomposition { struct JacobiUpdate { XlaOp v; XlaOp w; +}; + +struct FrobeniusNorms { XlaOp off_diagonal_norm; + XlaOp total_norm; }; // Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, @@ -79,10 +83,6 @@ StatusOr SymmetricShurDecomposition2x2(XlaOp a, XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - PrimitiveType type = a_shape.element_type(); - - const int64 num_dims = a_shape.rank(); - auto zero = ScalarLike(a, 0.0); auto one = ScalarLike(a, 1.0); auto two = ScalarLike(a, 2.0); @@ -110,9 +110,7 @@ StatusOr SymmetricShurDecomposition2x2(XlaOp a, schur.c = c * rnorm; schur.s = s * rnorm; - schur.reduction = - Reduce(two * Square(pqs), zero, CreateScalarAddComputation(type, builder), - {num_dims - 2, num_dims - 1}); + return schur; } @@ -196,12 +194,32 @@ StatusOr Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, jacobi_update.v = DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); - jacobi_update.off_diagonal_norm = Sqrt( - Max(Square(jacobi_update.off_diagonal_norm) - schur.reduction, pq_zero)); - return jacobi_update; } +StatusOr ComputeFrobeniusNorms(XlaOp w) { + XlaBuilder* builder = w.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); + const int64 num_dims = shape.rank(); + auto frobenius_norm = + Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2, num_dims - 1})); + auto diag = GetMatrixDiagonal(w); + auto diag_square = + Reduce(Square(diag), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + StatusOr> WhileLoopFn( absl::Span initial_values, // int matrix_dimension, // @@ -212,62 +230,108 @@ StatusOr> WhileLoopFn( auto while_cond_fn = [&](absl::Span values, XlaBuilder* cond_builder) -> StatusOr { auto k = values[0]; - auto off_diagonal_norm = values[5]; - // tol = frobenius_norm * epsilon. 
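Illustration (not part of the patch): ComputeFrobeniusNorms above avoids materializing the off-diagonal part of w by using the identity ||A||_F^2 = ||diag(A)||_2^2 + ||offdiag(A)||_F^2, so the off-diagonal norm is recovered from the total Frobenius norm and the norm of the diagonal alone. A quick numpy check of that identity:

import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(4, 4).astype(np.float32)
a = (a + a.T) / 2  # make it symmetric

frobenius_norm = np.linalg.norm(a)
diag_norm = np.linalg.norm(np.diag(a))
# Off-diagonal norm derived as in ComputeFrobeniusNorms.
off_diagonal_norm = np.sqrt(max(frobenius_norm**2 - diag_norm**2, 0.0))
# Reference: Frobenius norm of a with its diagonal zeroed out.
reference = np.linalg.norm(a - np.diag(np.diag(a)))
print(off_diagonal_norm, reference)  # the two agree up to rounding
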
- auto tol = values[6] * values[7]; - auto max_sweeps = ScalarLike(k, max_sweep_updates); - auto sweep_update_cond = Gt(max_sweeps, k); - auto tol_cond = ReduceAll(Lt(tol, off_diagonal_norm), + auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + auto tol = norms.total_norm * values[3]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), xla::ConstantR0(cond_builder, false), CreateScalarOrComputation(PRED, cond_builder)); - return And(tol_cond, sweep_update_cond); + + return And(sweep_update_cond, tol_cond); }; auto while_body_fn = [&](absl::Span values, XlaBuilder* body_builder) -> StatusOr> { - auto zero = Zero(body_builder, index_type); - auto one = One(body_builder, index_type); - auto end_index = ScalarLike(one, matrix_dimension); + auto while_cond_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_cond_builder) -> StatusOr { + auto p = values_inner[0]; + return Lt(p, ScalarLike(p, matrix_dimension - 1)); + }; + auto while_body_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_cond_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_cond_builder) -> StatusOr { + auto q = values_innermost[1]; + return Lt(q, ScalarLike(q, matrix_dimension)); + }; + auto while_body_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_body_builder) + -> StatusOr> { + auto p = values_innermost[0]; + auto q = values_innermost[1]; + + JacobiUpdate jacobi_update; + jacobi_update.v = values_innermost[2]; + jacobi_update.w = values_innermost[3]; + + auto tol = values_innermost[4]; + + TF_ASSIGN_OR_RETURN(jacobi_update, + Update(jacobi_update, p, q, tol, matrix_dimension)); + + std::vector updated_values_innermost; + updated_values_innermost.reserve(values_innermost.size()); + + updated_values_innermost.push_back(p); + updated_values_innermost.push_back(q + ScalarLike(q, 1)); + updated_values_innermost.push_back(jacobi_update.v); + updated_values_innermost.push_back(jacobi_update.w); + updated_values_innermost.push_back(tol); + + return updated_values_innermost; + }; + + std::vector values_innermost(5); + auto p = values_inner[0]; + auto q = p + ScalarLike(p, 1); + values_innermost[0] = p; // index p. + values_innermost[1] = q; // index q. + values_innermost[2] = values_inner[1]; // v. + values_innermost[3] = values_inner[2]; // w. + values_innermost[4] = values_inner[3]; // tol. + TF_ASSIGN_OR_RETURN( + values_innermost, + WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, + values_innermost, absl::StrCat(name, "-Innermost"), + inner_body_builder)); + + std::vector updated_values_inner; + updated_values_inner.reserve(values_inner.size()); + + updated_values_inner.push_back(p + ScalarLike(p, 1)); + updated_values_inner.push_back(values_innermost[2]); + updated_values_inner.push_back(values_innermost[3]); + updated_values_inner.push_back(values_innermost[4]); + return updated_values_inner; + }; // Indexes. XlaOp k = values[0]; - XlaOp p = values[1]; - XlaOp q = values[2]; - JacobiUpdate jacobi_update; - jacobi_update.v = values[3]; - jacobi_update.w = values[4]; - jacobi_update.off_diagonal_norm = values[5]; - - XlaOp frobenius_norm = values[6]; - XlaOp tol = values[7]; - - TF_ASSIGN_OR_RETURN(jacobi_update, - Update(jacobi_update, p, q, tol, matrix_dimension)); + std::vector values_inner(4); + values_inner[0] = ScalarLike(k, 0); // index p. + values_inner[1] = values[1]; // v. + values_inner[2] = values[2]; // w. + values_inner[3] = values[3]; // tol. 
+ TF_ASSIGN_OR_RETURN( + values_inner, + WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, + absl::StrCat(name, "-Inner"), body_builder)); std::vector updated_values; - updated_values.reserve(values.size()); + updated_values.reserve(values_inner.size()); - q = q + one; - p = Select(Eq(q, end_index), p + one, p); - k = Select(Eq(p, end_index - one), k + one, k); - p = Select(Eq(p, end_index - one), zero, p); - q = Select(Eq(q, end_index), p + one, q); - - updated_values.push_back(k); - updated_values.push_back(p); - updated_values.push_back(q); - - updated_values.push_back(jacobi_update.v); - updated_values.push_back(jacobi_update.w); - updated_values.push_back(jacobi_update.off_diagonal_norm); - - updated_values.push_back(frobenius_norm); - updated_values.push_back(tol); + updated_values.push_back(k + ScalarLike(k, 1)); + updated_values.push_back(values_inner[1]); + updated_values.push_back(values_inner[2]); + updated_values.push_back(values_inner[3]); return updated_values; }; @@ -278,6 +342,27 @@ StatusOr> WhileLoopFn( return values; } +StatusOr SortByEigenvalues(SelfAdjointEigResult result) { + XlaBuilder* builder = result.v.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({result.w, result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + result.v = GetTupleElement(sort_result, 1); + return result; +} + } // namespace // This is the cyclic Jacobi iteration. Please note that the eigenvalues are @@ -286,31 +371,35 @@ StatusOr> WhileLoopFn( // def jacobi(A): // n, _ = A.shape // V = np.eye(n) -// nfrob = np.sum(A ** 2) -// ndiag = np.sum(np.diag(A) ** 2) -// off = nfrob - ndiag -// while off > 1e-6 * nfrob: +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm: // for p in range(n - 1): // for q in range(p + 1, n): -// if off > 1e-6 * nfrob: -// c, s = sym_schur2x2(A, p, q) -// off = off - 2 * A[p, q] ** 2 -// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), -// A[[p, q], :]) -// A[:, [p, q]] = np.matmul(A[:, [p, q]], -// np.array([[c, s], [-s, c]])) -// V[:, [p, q]] = np.matmul(V[:, [p, q]], +// c, s = sym_schur2x2(A, p, q) +// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), +// A[[p, q], :]) +// A[:, [p, q]] = np.matmul(A[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// V[:, [p, q]] = np.matmul(V[:, [p, q]], // np.array([[c, s], [-s, c]])) +// frobenius_norm_sq = np.linalg.norm(A) +// diag_square_sum = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) // // return A, V // // TODO(kuny): Implement parallel order Jacobi. 
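Illustration (not part of the patch): a runnable numpy version of the pseudocode above. It follows the classical cyclic Jacobi iteration, recomputes the norms once per full sweep for the convergence test, the way the restructured while loops now do, and ends with the ascending eigenvalue sort that SortByEigenvalues performs. It is a sketch of the algorithm, not of the XLA implementation:

import numpy as np

def sym_schur2x2(a, p, q):
    # Rotation (c, s) that zeroes a[p, q] (classical symmetric Schur 2x2).
    if a[p, q] == 0.0:
        return 1.0, 0.0
    tau = (a[q, q] - a[p, p]) / (2.0 * a[p, q])
    if tau >= 0:
        t = 1.0 / (tau + np.sqrt(1.0 + tau * tau))
    else:
        t = -1.0 / (-tau + np.sqrt(1.0 + tau * tau))
    c = 1.0 / np.sqrt(1.0 + t * t)
    return c, t * c

def off_diag_norm(a):
    fro = np.linalg.norm(a)
    diag = np.linalg.norm(np.diag(a))
    return np.sqrt(max(fro * fro - diag * diag, 0.0))

def jacobi_eigh(a, max_sweeps=100, epsilon=1e-9):
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    v = np.eye(n)
    for _ in range(max_sweeps):
        if off_diag_norm(a) <= epsilon * np.linalg.norm(a):
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                c, s = sym_schur2x2(a, p, q)
                rot = np.array([[c, s], [-s, c]])
                a[[p, q], :] = rot.T @ a[[p, q], :]
                a[:, [p, q]] = a[:, [p, q]] @ rot
                v[:, [p, q]] = v[:, [p, q]] @ rot
    # Sort eigenvalues in ascending order and reorder the eigenvectors to
    # match, mirroring SortByEigenvalues.
    w = np.diag(a).copy()
    order = np.argsort(w)
    return w[order], v[:, order]

m = np.random.RandomState(0).randn(5, 5)
m = (m + m.T) / 2
w, v = jacobi_eigh(m)
print(np.allclose(v @ np.diag(w) @ v.T, m, atol=1e-6))  # True
print(np.allclose(w, np.linalg.eigh(m)[0], atol=1e-6))  # True
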
// -SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, - float epsilon) { +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, + float epsilon) { XlaBuilder* builder = a.builder(); auto return_error = [&](const Status& status) { - SelfAdjointEigenResult result; + SelfAdjointEigResult result; result.v = builder->ReportError(status); result.w = builder->ReportError(status); return result; @@ -348,33 +437,17 @@ SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto zero = ScalarLike(a, 0.0); auto tol = ScalarLike(a, epsilon); auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); auto w_init = Triangle(a, lower); w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; - auto frobenius_norm = Sqrt(Reduce(Square(w_init), zero, - CreateScalarAddComputation(type, builder), - {num_dims - 2, num_dims - 1})); - auto diag = GetMatrixDiagonal(w_init); - auto diag_square = - Reduce(Square(diag), zero, CreateScalarAddComputation(type, builder), - {num_dims - 2}); - - auto off_diagonal_init = - Sqrt(Max(Square(frobenius_norm) - diag_square, zero)); - auto output_with_status = WhileLoopFn( { Zero(builder, S32), // k - Zero(builder, S32), // p - One(builder, S32), // q - v_init, // - w_init, // - off_diagonal_init, // - frobenius_norm, // + v_init, // v + w_init, // w tol, // }, // n, // @@ -388,11 +461,11 @@ SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter, auto output = output_with_status.ValueOrDie(); - SelfAdjointEigenResult result; - result.v = output[3]; - result.w = GetMatrixDiagonal(output[4]); + SelfAdjointEigResult result; + result.v = output[1]; + result.w = GetMatrixDiagonal(output[2]); - return result; + return SortByEigenvalues(result).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h similarity index 71% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig.h index 49fc17aa275..2a089891d6a 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -23,20 +23,18 @@ namespace xla { // The eigenvalue decomposition of a symmetric matrix, the original matrix is // recovered by v * w * v_t. -struct SelfAdjointEigenResult { +struct SelfAdjointEigResult { // The i-th column is the normalized eigenvector corresponding to the // eigenvalue w[i]. Will return a matrix object if a is a matrix object. XlaOp v; - // TODO(kuny): Sort the eigenvalues. // The eigenvalues in ascending order, each repeated according to its // multiplicity. 
XlaOp w; }; -SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower = true, - int64 max_iter = 100, - float epsilon = 1e-6); +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64 max_iter = 100, float epsilon = 1e-6); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc similarity index 79% rename from tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc rename to tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index 720c49b7716..c8875dff7bf 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -32,7 +32,7 @@ limitations under the License. namespace xla { -class SelfAdjointEigenTest : public ClientLibraryTestBase { +class SelfAdjointEigTest : public ClientLibraryTestBase { protected: void SetUp() override { ClientLibraryTestBase::SetUp(); @@ -71,7 +71,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { } void TearDown() override { ClientLibraryTestBase::TearDown(); } - Array3D get_unit_matrix_3d(const Array3D& matrix) { + Array3D GetUnitMatrix3D(const Array3D& matrix) { Array3D result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); for (int i = 0; i < matrix.n1(); ++i) { for (int j = 0; j < matrix.n2(); ++j) { @@ -100,7 +100,7 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { return result; } - XlaOp ComputeMatmulVWVt(SelfAdjointEigenResult result, XlaBuilder* builder) { + XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { Shape shape = builder->GetShape(result.v).ValueOrDie(); std::vector out_dims = shape.dimensions(); std::vector broadcast_dims(shape.rank() - 1); @@ -140,69 +140,69 @@ class SelfAdjointEigenTest : public ClientLibraryTestBase { Array2D wrong_type_4x4_; }; -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Lower_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter( ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Upper_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter( 
ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a, false); + auto result = SelfAdjointEig(a, false); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_2x4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); - ComputeAndCompareR3(&builder, get_unit_matrix_3d(batch_3d_4x4_), + ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { +XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR2Parameter(low_rank_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); ComputeMatmulVWVt(result, &builder); ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { XlaBuilder builder(TestName()); // This is computed by numpy.linalg.eigh with float32. @@ -211,21 +211,21 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) { XlaOp a; auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); - Sort(result.w); + auto result = SelfAdjointEig(a); + Add(result.w, ZerosLike(result.w)); ComputeAndCompareR1(&builder, expected, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { XlaBuilder builder(TestName()); float expected_vals = 1e-3; XlaOp a; auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), BatchDot(TransposeInMinorDims(result.v), result.v), @@ -235,66 +235,79 @@ XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) { ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Wrong_Type_Int) { +XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR2Parameter(wrong_type_4x4_, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); EXPECT_FALSE(result.v.valid()); EXPECT_FALSE(result.w.valid()); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_8x8) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { XlaBuilder builder(TestName()); int size = 8; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_16x16) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { XlaBuilder builder(TestName()); int size = 16; Array2D a_val = 
GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_32x32) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { XlaBuilder builder(TestName()); int size = 32; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_64x64) { +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { XlaBuilder builder(TestName()); - int size = 64; + int size = 256; Array2D a_val = GenerateRandomSymmetricMatrix(size); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SelfAdjointEigen(a); + auto result = SelfAdjointEig(a); GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { + XlaBuilder builder(TestName()); + int size = 512; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 3245f46e6fd..ddc39f4d874 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -36,7 +36,8 @@ XlaOp TopK(XlaOp input, int64 k) { XlaOp sort_result = Sort({Neg(input), iota_s32}, CreateScalarLtComputation({input_shape.element_type(), S32}, - iota_s32.builder())); + iota_s32.builder()), + last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index ae78910a5b4..0fbd138aca1 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -81,9 +81,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } -// TODO(b/122298745): Enable this test when the GPU backend supports stable -// sorting. 
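Illustration (not part of the patch): TopK above now performs a stable sort of (negated value, index) pairs, so ties between equal values are broken by the original index and the result is deterministic across backends; that is what allows the TopKFullSortWithDuplicates test below to be re-enabled. Python's built-in sort is also stable, so the tie-breaking behaviour can be shown directly (top_k here is a hypothetical helper, not an XLA API):

def top_k(values, k):
    # Sort by negated value only; stability keeps equal values in their
    # original index order, mirroring the stable Sort used by TopK.
    order = sorted(range(len(values)), key=lambda i: -values[i])
    return [values[i] for i in order[:k]], order[:k]

print(top_k([1, 1, 2, 2, 1], k=3))  # ([2, 2, 1], [2, 3, 0])
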
-XLA_TEST_F(SortingTest, DISABLED_ON_GPU(TopKFullSortWithDuplicates)) { +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { XlaBuilder builder(TestName()); XlaOp a; auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index fb9dbe851e7..b371b5af37b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1663,14 +1663,16 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, Lt(first_lhs_param, first_rhs_param); TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); - return Sort(operands, comparator, dimension); + return Sort(operands, comparator, dimension, /*is_stable=*/false); }); } XlaOp XlaBuilder::Sort(absl::Span operands, - const XlaComputation& comparator, int64 dimension) { + const XlaComputation& comparator, int64 dimension, + bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(std::vector operand_shapes, GetOperandShapes(operands)); @@ -3320,8 +3322,9 @@ XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { } XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension) { - return operands[0].builder()->Sort(operands, comparator, dimension); + int64 dimension, bool is_stable) { + return operands[0].builder()->Sort(operands, comparator, dimension, + is_stable); } XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 1e39c8766f3..fd2e9816e8a 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -505,7 +505,7 @@ class XlaBuilder { XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension = -1); + int64 dimension = -1, bool is_stable = false); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -923,7 +923,8 @@ class XlaBuilder { friend XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension); friend XlaOp Sort(absl::Span operands, - const XlaComputation& comparator, int64 dimension); + const XlaComputation& comparator, int64 dimension, + bool is_stable); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1695,7 +1696,8 @@ XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a sort instruction onto the computation, using 'comparator' for -// comparisons. 'comparator' needs to define a strict weak order. +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. // If only one operand is provided: // * If the operand is a rank-1 tensor (an array), the result is a sorted array. // The resulting sorting order has the property that for all index positions @@ -1718,7 +1720,7 @@ XlaOp Sort(const XlaOp& keys, absl::Span values = {}, // correspond to the value of operand i at two index positions. 
// Default comparator computations can be found in lib/comparators.h XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64 dimension = -1); + int64 dimension = -1, bool is_stable = false); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 0f9b591c70d..230f3b202a4 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { } ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( - DeviceAssignment* device_assignment) { + const DeviceAssignment* device_assignment) { device_assignment_ = device_assignment; return *this; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 6f36d11dfb3..1e744953bd3 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -74,7 +74,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_device_assignment( - DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; ExecutableRunOptions& set_rng_seed(int rng_seed); @@ -83,7 +83,7 @@ class ExecutableRunOptions { private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; - DeviceAssignment* device_assignment_ = nullptr; + const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index f7e2d26b7aa..a0687e0d523 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -77,6 +77,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 24138a173de..671953aefe1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -53,74 +54,6 @@ namespace swig { // TODO(b/118641336): Factor out XRT parts into a small c++ library of their // own. 
-// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of -// device handles instead of needing to set the number of replicas at XLA -// service initialization time. -tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); -int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; -LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; - -string* GetPlatformNameString() { - static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = - new string("Host"); - return platform_name_string; -} - -Status InitializeReplicaCount(int replica_count) { - if (replica_count < 1) { - return InvalidArgument("Replica count must be >= 1; got %d.", - replica_count); - } - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the replica count to %d, but a local XLA service was " - "previously created with a replica count of %d.", - replica_count, g_replica_count); - } - g_replica_count = replica_count; - return Status::OK(); -} - -Status InitializePlatformName(const string& platform_name) { - string* g_platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the platform name to %s, but a local XLA service was " - "previously created with a platform name of %s.", - platform_name, *g_platform_name); - } - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform(platform_name)); - if (platform->VisibleDeviceCount() <= 0) { - return InvalidArgument("Platform %s has no visible devices.", - platform_name); - } - *g_platform_name = platform_name; - return Status::OK(); -} - -int GetReplicaCount() { - tensorflow::mutex_lock lock(g_local_client_mutex); - return g_replica_count; -} - -StatusOr GetOrCreateLocalClient() { - string* platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return g_local_client; - } - LocalClientOptions options; - options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); - options.set_number_of_replicas(g_replica_count); - TF_ASSIGN_OR_RETURN(g_local_client, - ClientLibrary::GetOrCreateLocalClient(options)); - CHECK(g_local_client != nullptr); - return g_local_client; -} - Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; if (!PyCapsule_IsValid(capsule, name)) { @@ -135,62 +68,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { return Status::OK(); } -Status TransferToInfeedLocal(const Literal& literal) { - VLOG(1) << "Infeeding literal without replica number; shape: " - << literal.shape(); - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); +LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} + +/* static */ StatusOr LocalClient::Get( + const string& platform_name) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s has no visible devices.", + platform_name); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + CHECK(client != nullptr); + return 
LocalClient(client); } -Status TransferToInfeedLocalReplica(const Literal& literal, - int replica_number) { - VLOG(1) << "Infeeding shape " << literal.shape() - << " to replica number: " << replica_number; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferToInfeedLocal(literal, device_ordinal); +// Returns the number of devices known to the XLA client. +int LocalClient::DeviceCount() const { return client_->device_count(); } + +Status LocalClient::TransferToInfeed(const Literal& literal, + int device_ordinal) { + VLOG(1) << "Infeeding literal to device " << device_ordinal + << "; shape: " << literal.shape(); + return client_->TransferToInfeed(literal, device_ordinal); } -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number) { - VLOG(1) << "Outfeeding literal from replica number: " << replica_number - << " shape: " << shape; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferFromOutfeedLocal(shape, device_ordinal); -} - -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); +StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + VLOG(1) << "Outfeeding literal from device " << device_ordinal + << "; shape: " << shape; + return client_->TransferFromOutfeed(&shape, device_ordinal); } /* static */ StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " - << replica_number << "/" << device_ordinal; + const LocalClient& client, int device_ordinal) { + VLOG(1) << "Creating shaped buffer from literal on device ordinal: " + << device_ordinal; + auto literal_to_buffer = [&](const Literal& arg) { + return client.client()->LiteralToShapedBuffer( + arg, device_ordinal, client.client()->backend().memory_allocator()); + }; + StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, device_ordinal, relaid); + return literal_to_buffer(relaid); } - return ToBuffer(client, device_ordinal, argument); + return literal_to_buffer(argument); }(); TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); } -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) - : shaped_buffer_(std::move(shaped_buffer)) {} +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, + xla::LocalClient* client) + : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; @@ -203,8 +140,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->ShapedBufferToLiteral(*shaped_buffer()); + return 
client_->ShapedBufferToLiteral(*shaped_buffer()); } LocalShapedBufferTuple::LocalShapedBufferTuple( @@ -235,6 +171,51 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } +StatusOr LocalShapedBuffer::DestructureTuple() { + const Shape tuple_shape = shape(); + + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); + } + // Deallocate the root buffer. 
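Illustration (not part of the patch): DestructureTuple above peels a tuple-shaped buffer apart by walking every subshape of element i, moving the device memory registered at shape index (i, ...) into a fresh per-element buffer at index (...), and finally deallocating only the root tuple buffer. A toy Python sketch of that index re-rooting, with plain tuples standing in for ShapeIndex and a dict for the buffer tree; all names here are illustrative, not XLA APIs:

def destructure(buffer_tree, num_elements):
    # buffer_tree maps a shape index (a tuple of ints) to a device buffer.
    # Element i of the result takes every entry whose index starts with i,
    # re-keyed with that leading i stripped off.
    elements = []
    for i in range(num_elements):
        sub = {}
        for index, buf in list(buffer_tree.items()):
            if index[:1] == (i,):
                sub[index[1:]] = buf
                buffer_tree[index] = None  # ownership moved out of the tuple
        elements.append(sub)
    return elements

tree = {(): "root", (0,): "leaf_a", (1,): "subtuple", (1, 0): "leaf_b"}
print(destructure(tree, num_elements=2))
# [{(): 'leaf_a'}, {(): 'subtuple', (0,): 'leaf_b'}]
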
+ se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} + XrtAllocation::XrtAllocation(int64 handle, Shape shape, const string& session_target) : handle_(handle), shape_(shape), session_target_(session_target) {} @@ -332,23 +313,32 @@ StatusOr XrtAllocationTuple::Release(int i) { int64 XrtAllocationTuple::size() const { return elements_.size(); } -CompiledLocalComputation::CompiledLocalComputation( - std::unique_ptr executable) - : executable_(std::move(executable)) {} +LocalExecutable::LocalExecutable( + std::unique_ptr executable, + xla::DeviceAssignment device_assignment, xla::LocalClient* client) + : executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)), + client_(client) {} -StatusOr CompiledLocalComputation::Execute( +std::vector LocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); + } + return device_ordinals; +} + +StatusOr LocalExecutable::Execute( absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute()", num_replicas()); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - 1, /*computation_count=*/1)); StatusOr result_buffer_status; - const int device_ordinal = device_assignment(0, 0); + const int device_ordinal = device_assignment_(0, 0); VLOG(3) << "Replica 0 mapped to device ordinal for execution: " << device_ordinal; @@ -360,10 +350,10 @@ StatusOr CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); result_buffer_status = executable_->Run(argument_buffers, options); @@ -373,13 +363,13 @@ StatusOr CompiledLocalComputation::Execute( "%s.", result_buffer_status.status().ToString()); } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), + client_); } -StatusOr CompiledLocalComputation::ExecutePerReplica( +StatusOr LocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - const int num_devices = client->device_count(); + const int num_devices = client_->device_count(); if (argument_handles.size() != num_replicas()) { return InvalidArgument( @@ -394,14 +384,9 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( VLOG(1) << "Executing with " << num_replicas() << " replicas."; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - num_replicas(), /*computation_count=*/1)); - std::vector> results(num_replicas()); - auto execute = [this, client, &device_assignment, &argument_handles, - 
&results](int replica) { - const int device_ordinal = device_assignment(replica, 0); + auto execute = [this, &argument_handles, &results](int replica) { + const int device_ordinal = device_assignment_(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -413,10 +398,10 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); StatusOr result_buffer_status = executable_->Run(argument_buffers, options); @@ -448,26 +433,19 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( replica, statusor.status().ToString()); } wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie()); + new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); } return new LocalShapedBufferTuple(std::move(wrapped_results)); } -static StatusOr GetReturnValueShape(const XlaComputation& computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); - return std::move(*program_shape.mutable_result()); -} - -CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle, - const string& session_target) +XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target) : program_shape_(program_shape), handle_(handle), session_target_(session_target) {} -CompiledXrtComputation::~CompiledXrtComputation() { +XrtExecutable::~XrtExecutable() { tensorflow::Scope root = tensorflow::Scope::NewRootScope(); auto computation_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); @@ -489,7 +467,7 @@ CompiledXrtComputation::~CompiledXrtComputation() { } } -StatusOr CompiledXrtComputation::Execute( +StatusOr XrtExecutable::Execute( absl::Span argument_handles) { const int num_expected_arguments = program_shape().parameters().size(); @@ -528,36 +506,41 @@ StatusOr CompiledXrtComputation::Execute( return new XrtAllocation(output, program_shape().result(), session_target_); } -const ProgramShape& CompiledXrtComputation::program_shape() const { +const ProgramShape& XrtExecutable::program_shape() const { return program_shape_; } -int64 CompiledXrtComputation::handle() const { return handle_; } +int64 XrtExecutable::handle() const { return handle_; } -LocalComputation::LocalComputation(XlaComputation computation) +Computation::Computation(XlaComputation computation) : computation_(std::move(computation)) {} -StatusOr LocalComputation::Compile( +StatusOr Computation::Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options) { + const ExecutableBuildOptions* build_options, const LocalClient& client) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { argument_shape_pointers.push_back(&argument_shape); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; } TF_ASSIGN_OR_RETURN( auto local_executable, - client->Compile(computation_, 
argument_shape_pointers, options)); - return new CompiledLocalComputation(std::move(local_executable)); + client.client()->Compile(computation_, argument_shape_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client.client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); + + return new LocalExecutable(std::move(local_executable), + std::move(device_assignment), client.client()); } -StatusOr LocalComputation::CompileForXrt( +StatusOr Computation::CompileForXrt( const std::vector& argument_shapes, const string& session_target) { tensorflow::Scope root = tensorflow::Scope::NewRootScope(); auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); @@ -585,14 +568,12 @@ StatusOr LocalComputation::CompileForXrt( TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation().GetProgramShape()); int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle, session_target); + return new XrtExecutable(program_shape, handle, session_target); } -const XlaComputation& LocalComputation::computation() const { - return computation_; -} +const XlaComputation& Computation::computation() const { return computation_; } -string LocalComputation::GetSerializedProto() const { +string Computation::GetSerializedProto() const { string result; if (!computation_.proto().SerializeToString(&result)) { LOG(ERROR) << "Failed to serialize the HloModuleProto."; @@ -601,101 +582,103 @@ string LocalComputation::GetSerializedProto() const { return result; } -StatusOr LocalComputation::GetReturnValueShape() const { - return swig::GetReturnValueShape(computation_); +StatusOr Computation::GetProgramShape() const { + return computation_.GetProgramShape(); +} + +StatusOr Computation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); + return std::move(*shape.mutable_result()); } LocalOp::LocalOp(const XlaOp& op) : op_(op) {} const XlaOp& LocalOp::op() const { return op_; } -LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) +ComputationBuilder::ComputationBuilder(const string& computation_name) : builder_(computation_name) {} -void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { +void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); } -void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } +void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } -StatusOr LocalComputationBuilder::Build() { +StatusOr ComputationBuilder::Build() { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, const string& name) { return xla::Parameter(&builder_, parameter_number, shape, name); } -StatusOr LocalComputationBuilder::BuildWithRoot( - const LocalOp& root) { +StatusOr ComputationBuilder::BuildWithRoot(const LocalOp& root) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { +StatusOr 
ComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } -StatusOr LocalComputationBuilder::GetReturnValueShape() { +StatusOr ComputationBuilder::GetReturnValueShape() { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); return program_shape.result(); } -LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp ComputationBuilder::Infeed(const Shape& shape) { return xla::Infeed(&builder_, shape); } -void LocalComputationBuilder::Outfeed(const LocalOp& operand, - const Shape& shape, - const string& outfeed_config) { +void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, + const string& outfeed_config) { xla::Outfeed(operand.op(), shape, outfeed_config); } -LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { +LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } -LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { +LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { return xla::Iota(&builder_, element_type, size); } -LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, - int64 dimension) { +LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { return xla::Iota(&builder_, shape, dimension); } -LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, absl::Span broadcast_sizes) { +LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, + absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } -LocalOp LocalComputationBuilder::BroadcastInDim( +LocalOp ComputationBuilder::BroadcastInDim( const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } -LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { +LocalOp ComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { +LocalOp ComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::AllToAll( +LocalOp ComputationBuilder::AllToAll( const LocalOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, absl::Span replica_groups) { std::vector rg(replica_groups.size()); @@ -706,39 +689,38 @@ LocalOp LocalComputationBuilder::AllToAll( split_count, rg); } -LocalOp LocalComputationBuilder::CrossReplicaSum( +LocalOp ComputationBuilder::CrossReplicaSum( const LocalOp& operand, absl::Span replica_groups) { return xla::CrossReplicaSum(operand.op(), replica_groups); } -LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides) { +LocalOp ComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span 
limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } -LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, - int64 limit_index, int64 stride, - int64 dimno) { +LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } -LocalOp LocalComputationBuilder::DynamicSlice( - const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes) { +LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, + const LocalOp& start_indices, + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -LocalOp LocalComputationBuilder::DynamicUpdateSlice( - const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices) { +LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, + const LocalOp& update, + const LocalOp& start_indices) { return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { +LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -747,18 +729,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, return xla::ConcatInDim(&builder_, xla_ops, dimension); } -LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, +LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter) { + const LocalOp& init_value, const Computation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { +LocalOp ComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -768,22 +750,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { return xla::Tuple(&builder_, xla_ops); } -LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { +LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { return xla::GetTupleElement(tuple_data.op(), index); } -LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { +LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { return xla::Dot(lhs.op(), rhs.op()); } -LocalOp LocalComputationBuilder::DotGeneral( +LocalOp ComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -LocalOp LocalComputationBuilder::ConvGeneralDilated( +LocalOp ComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -795,18 +777,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( feature_group_count); } -LocalOp 
LocalComputationBuilder::ConvertElementType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::ConvertElementType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::BitcastConvertType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, - absl::Span operands) { +LocalOp ComputationBuilder::Call(const Computation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -815,7 +797,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } -LocalOp LocalComputationBuilder::CustomCall( +LocalOp ComputationBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const std::vector& operand_shapes_with_layout, @@ -830,19 +812,19 @@ LocalOp LocalComputationBuilder::CustomCall( operand_shapes_with_layout, opaque); } -LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, absl::Span permutation) { +LocalOp ComputationBuilder::Transpose(const LocalOp& operand, + absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map(absl::Span operands, - const LocalComputation& local_computation, - absl::Span dimensions) { +LocalOp ComputationBuilder::Map(absl::Span operands, + const Computation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -853,17 +835,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span operands, dimensions); } -LocalOp LocalComputationBuilder::Reduce( +LocalOp ComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } -LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( +LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -875,51 +857,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( padding); } -LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, - const LocalOp& sigma, - const Shape& shape) { +LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape) { return xla::RngNormal(mu.op(), sigma.op(), shape); } -LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { +LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { return 
xla::RngUniform(a.op(), b.op(), shape); } -LocalOp LocalComputationBuilder::While(const LocalComputation& condition, - const LocalComputation& body, - const LocalOp& init) { +LocalOp ComputationBuilder::While(const Computation& condition, + const Computation& body, + const LocalOp& init) { return xla::While(condition.computation(), body.computation(), init.op()); } -LocalOp LocalComputationBuilder::Conditional( - const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation) { +LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, + const LocalOp& true_operand, + const Computation& true_computation, + const LocalOp& false_operand, + const Computation& false_computation) { return xla::Conditional(predicate.op(), true_operand.op(), true_computation.computation(), false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { +StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } -LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { +LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { return xla::Sort(operand.op(), {}, dimension); } -LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, - int64 dimension) { +LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, int64 dimension) { return xla::Sort(keys.op(), {values.op()}, dimension); } -LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { +LocalOp ComputationBuilder::Cholesky(const LocalOp& a) { return xla::Cholesky(a.op()); } -LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { +LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { XlaBuilder* builder = a.op().builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); @@ -927,17 +908,16 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { }); } -LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, - const LocalOp& b, - bool left_side, bool lower, - bool unit_diagonal, - int transpose_a) { +LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, + bool left_side, bool lower, + bool unit_diagonal, + int transpose_a) { return xla::TriangularSolve( a.op(), b.op(), left_side, lower, unit_diagonal, xla::TriangularSolveOptions::Transpose(transpose_a)); } -LocalOp LocalComputationBuilder::Gather( +LocalOp ComputationBuilder::Gather( const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes) { @@ -945,24 +925,24 @@ LocalOp LocalComputationBuilder::Gather( slice_sizes); } -LocalOp LocalComputationBuilder::Scatter( +LocalOp ComputationBuilder::Scatter( const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers) { return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), update_computation.computation(), dimension_numbers); } -StatusOr LocalComputationBuilder::BuildConstantSubGraph( +StatusOr ComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation 
computation, builder_.BuildConstantSubGraph(operand.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig LocalComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig ComputationBuilder::method_name args_sig { \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -1051,64 +1031,11 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { - delete computation; -} +void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { - delete computation; -} +void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } -void DeleteLocalComputation(LocalComputation* computation) { - delete computation; -} - -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer) { - const Shape tuple_shape = local_shaped_buffer->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - DeviceMemoryAllocator* allocator = - local_shaped_buffer->shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); - - // Extract some metadata we use to construct scoped buffers. - const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. - const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator))); - } - // Deallocate the root buffer. - se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} +void DeleteComputation(Computation* computation) { delete computation; } StatusOr DestructureXrtAllocationTuple( XrtAllocation* allocation, const string& session_target) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index bc8b7e610c0..9ff46d57dc6 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -35,42 +35,42 @@ limitations under the License. 
namespace xla { namespace swig { -// Initializes the number of replicas that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializeReplicaCount(int replica_count); - -// Initializes the platform name that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializePlatformName(const string& platform_name); - -// Returns the replica count that is currently set, regardless of whether the -// local XLA service has been instantiated yet or not. -int GetReplicaCount(); - // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // "xla._CPU_CUSTOM_CALL_TARGET". Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); -// Wraps the local client's infeed-transfer function. -// -// The default device ordinal (0) is used. -Status TransferToInfeedLocal(const Literal& literal); +// Wrapper around an xla::LocalClient. +class LocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + /// such platform exists, or if the platform has no visible devices. + static StatusOr Get(const string& platform_name); -// Transfers the given literal to the infeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + // Copyable and moveable; the class is just a wrapper around a + // xla::LocalClient pointer for convenient SWIG wrapping. -// Transfers a literal of the given shape from the outfeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number); + // Returns the number of devices known to the XLA client. + int DeviceCount() const; + + // Wraps the local client's infeed-transfer function. + // + // The default device ordinal (0) is used. + Status TransferToInfeed(const Literal& literal, int device_ordinal); + + // Transfers a literal of the given shape from the outfeed of the given + // replica. + StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); + + xla::LocalClient* client() const { return client_; } + + private: + LocalClient(xla::LocalClient* client); + + xla::LocalClient* client_; +}; + +class LocalShapedBufferTuple; // Represents a reference to literals that live in a device-allocated buffer via // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a @@ -79,9 +79,9 @@ class LocalShapedBuffer { public: static StatusOr FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number); + const LocalClient& client, int device_ordinal); - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); StatusOr ToLiteral() const; const Shape& shape() const; const ScopedShapedBuffer* shaped_buffer() const; @@ -90,8 +90,13 @@ class LocalShapedBuffer { // analogous to std::unique_ptr::release(). ShapedBuffer Release(); + // Destructures a tuple-valued LocalShapedBuffer into its constitutent + // elements in LocalShapedBufferTuple form. 
+ StatusOr DestructureTuple(); + private: ScopedShapedBuffer shaped_buffer_; + xla::LocalClient* client_; }; // Result of a tuple destructuring operation on a LocalShapedBuffer -- this @@ -117,11 +122,6 @@ class LocalShapedBufferTuple { std::vector elements_; }; -// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements -// in LocalShapedBufferTuple form. -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer); - // Represents a reference to literals that live in a device-allocated buffer via // XRT. Specifically, wraps an int64 handle produced by running the allocation // graph, and an XLA shape to track the referent's shape. @@ -176,14 +176,19 @@ StatusOr DestructureXrtAllocationTuple( // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. -class CompiledLocalComputation { +class LocalExecutable { public: - CompiledLocalComputation(std::unique_ptr executable); + LocalExecutable(std::unique_ptr executable, + xla::DeviceAssignment device_assignment, + xla::LocalClient* client); int num_replicas() const { return executable_->build_options().num_replicas(); } + // Returns the device ordinals to which each replica is assigned. + std::vector DeviceOrdinals() const; + StatusOr Execute( absl::Span argument_handles); @@ -194,18 +199,22 @@ class CompiledLocalComputation { absl::Span > argument_handles); private: - std::unique_ptr executable_; + const std::unique_ptr executable_; + const xla::DeviceAssignment device_assignment_; + xla::LocalClient* const client_; }; // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XRT computation handle. -class CompiledXrtComputation { +class XrtExecutable { public: // Accepts a `session_target` argument, used in constructing the // `tensorflow::ClientSession` instance in which the execution graph is run. - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~CompiledXrtComputation(); + XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~XrtExecutable(); + + std::vector DeviceOrdinals() const { return {0}; } StatusOr Execute( absl::Span argument_handles); @@ -219,21 +228,21 @@ class CompiledXrtComputation { const string session_target_; }; -// Wraps a XlaComputation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a ComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. -class LocalComputation { +class Computation { public: - LocalComputation(XlaComputation computation); + Computation(XlaComputation computation); - StatusOr Compile( + StatusOr Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options); + const ExecutableBuildOptions* build_options, const LocalClient& client); // Accepts a `session_target` argument, used in constructing the // `tensorflow::ClientSession` instance in which the compilation graph is run. - StatusOr CompileForXrt( + StatusOr CompileForXrt( const std::vector& argument_shapes, const string& session_target); const XlaComputation& computation() const; @@ -243,6 +252,9 @@ class LocalComputation { // string on failure. string GetSerializedProto() const; + // Returns the program shape for this computation. 
+ StatusOr GetProgramShape() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -250,7 +262,7 @@ class LocalComputation { XlaComputation computation_; }; -// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// Wraps a XlaOp produced by a ComputationBuilder. This class is intended // to be made available to Python via SWIG. class LocalOp { public: @@ -267,20 +279,20 @@ class LocalOp { // Python. // - Set up the underlying builder to use the client library's // LocalClient. -// - Wrap Computations in LocalComputations for Python access. -// - Correspondingly unwrap incoming LocalComputations. -class LocalComputationBuilder { +// - Wrap Computations in Computations for Python access. +// - Correspondingly unwrap incoming Computations. +class ComputationBuilder { public: - LocalComputationBuilder(const string& computation_name); + ComputationBuilder(const string& computation_name); void SetOpMetadata(const OpMetadata& metadata); void ClearOpMetadata(); - // Returns an owned LocalComputation to the caller on success. - StatusOr Build(); + // Returns an owned Computation to the caller on success. + StatusOr Build(); - // Returns an owned LocalComputation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); + // Returns an owned Computation to the caller on success with given root. + StatusOr BuildWithRoot(const LocalOp& root); LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -339,11 +351,11 @@ class LocalComputationBuilder { LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter); + const LocalOp& init_value, const Computation& scatter); LocalOp Tuple(absl::Span elements); @@ -369,7 +381,7 @@ class LocalComputationBuilder { LocalOp BitcastConvertType(const LocalOp& operand, PrimitiveType new_element_type); - LocalOp Call(const LocalComputation& local_computation, + LocalOp Call(const Computation& local_computation, absl::Span operands); LocalOp CustomCall(const string& call_target_name, @@ -384,16 +396,16 @@ class LocalComputationBuilder { LocalOp Rev(const LocalOp& operand, absl::Span dimensions); LocalOp Map(absl::Span operands, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -405,13 +417,13 @@ class LocalComputationBuilder { LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - LocalOp While(const LocalComputation& condition, const LocalComputation& body, + LocalOp While(const Computation& condition, const Computation& body, const LocalOp& init); LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, + const Computation& true_computation, 
const LocalOp& false_operand, - const LocalComputation& false_computation); + const Computation& false_computation); StatusOr IsConstant(const LocalOp& operand); @@ -435,11 +447,10 @@ class LocalComputationBuilder { absl::Span slice_sizes); LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, - const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - StatusOr BuildConstantSubGraph(const LocalOp& operand); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -529,9 +540,9 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); -void DeleteLocalComputation(LocalComputation* computation); +void DeleteLocalExecutable(LocalExecutable* computation); +void DeleteXrtExecutable(XrtExecutable* computation); +void DeleteComputation(Computation* computation); } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index df2ab0b539b..5327ce91dbe 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -23,11 +23,13 @@ limitations under the License. // C++ Python // -------------------------------------+--------------------------------------- // Span <- sequence of int +// vector -> sequence of int // Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int // Span> <- sequence of int pairs @@ -97,7 +99,7 @@ limitations under the License. // wrapped in a Python class (xla_client.Shape) so as not to expose // the raw pair externally. // -// Other SWIG object wrappers (e.g. of LocalComputation) are further +// Other SWIG object wrappers (e.g. of Computation) are further // wrapped by xla_client in order to set up a custom destructor that // triggers memory deallocation on the C++ side. 
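To make the ownership comment above concrete, here is a minimal Python sketch of that wrapping pattern. The class name `_WrappedCppObject` is hypothetical and `c_api` stands in for the SWIG extension module; the real xla_client wrappers follow this shape but are not literally this class.

# A minimal sketch, not the actual xla_client classes: each Python wrapper
# owns one SWIG object plus the matching Delete* free function (for example
# c_api.DeleteComputation or c_api.DeleteLocalExecutable) and releases the
# C++ resource when the wrapper is no longer needed.
class _WrappedCppObject(object):  # hypothetical name

  def __init__(self, c_object, delete_fn):
    self._c_object = c_object    # the SWIG-owned C++ object
    self._delete_fn = delete_fn  # e.g. c_api.DeleteComputation

  def delete(self):
    if self._c_object is not None:
      self._delete_fn(self._c_object)
      self._c_object = None

  def __del__(self):
    # At interpreter shutdown the c_api module may already be torn down,
    # so never raise out of a destructor.
    try:
      self.delete()
    except Exception:  # pylint: disable=broad-except
      pass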
@@ -214,6 +216,15 @@ tensorflow::ImportNumpy(); // Basic types + +%typemap(out) std::vector { + PyObject* out = PyList_New($1.size()); + for (int i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); + } + $result = out; +} + %typemap(out) StatusOr { if ($1.ok()) { $result = PyBool_FromLong($1.ConsumeValueOrDie()); @@ -287,12 +298,12 @@ tensorflow::ImportNumpy(); // Computation and buffer/allocation types -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - auto* value = $1.ValueOrDie(); + xla::swig::LocalClient value = $1.ValueOrDie(); { - auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + auto $1 = value; + $typemap(out, xla::swig::LocalClient) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -300,12 +311,25 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledXrtComputation*) + $typemap(out, xla::swig::LocalExecutable*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtExecutable*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -365,12 +389,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::Computation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -519,18 +543,30 @@ tensorflow::ImportNumpy(); // Shape %typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1); + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); } %typemap(out) StatusOr { if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } } + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyProgramShapeInfoFromXlaProgramShape( + $1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + %typemap(in) const Shape& (Shape temp) { StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { @@ -558,7 +594,7 @@ tensorflow::ImportNumpy(); } %typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1); + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); } %typemap(in) const std::vector& (std::vector temps) { @@ -966,17 +1002,17 @@ tensorflow::ImportNumpy(); %ignoreall %unignore xla; %unignore xla::swig; -%unignore xla::swig::InitializeReplicaCount; -%unignore xla::swig::InitializePlatformName; -%unignore xla::swig::GetReplicaCount; %unignore xla::swig::RegisterCpuCustomCallTarget; -%unignore xla::swig::TransferToInfeedLocal; -%unignore xla::swig::TransferToInfeedLocalReplica; -%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalClient; +%unignore xla::swig::LocalClient::Get; +%unignore xla::swig::LocalClient::DeviceCount; +%unignore xla::swig::LocalClient::TransferToInfeed; +%unignore xla::swig::LocalClient::TransferFromOutfeed; %unignore xla::swig::LocalShapedBuffer; %unignore 
xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore xla::swig::LocalShapedBuffer::shape; +%unignore xla::swig::LocalShapedBuffer::DestructureTuple; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; @@ -987,139 +1023,141 @@ tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocationTuple; %unignore xla::swig::XrtAllocationTuple::Release; %unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::CompiledLocalComputation; -%unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; -%unignore xla::swig::CompiledXrtComputation; -%unignore xla::swig::CompiledXrtComputation::Execute; -%unignore xla::swig::LocalComputation; -%unignore xla::swig::LocalComputation::Compile; -%unignore xla::swig::LocalComputation::CompileForXrt; -%unignore xla::swig::LocalComputation::GetReturnValueShape; -%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalExecutable; +%unignore xla::swig::LocalExecutable::DeviceOrdinals; +%unignore xla::swig::LocalExecutable::Execute; +%unignore xla::swig::LocalExecutable::ExecutePerReplica; +%unignore xla::swig::XrtExecutable; +%unignore xla::swig::XrtExecutable::DeviceOrdinals; +%unignore xla::swig::XrtExecutable::Execute; +%unignore xla::swig::Computation; +%unignore xla::swig::Computation::Compile; +%unignore xla::swig::Computation::CompileForXrt; +%unignore xla::swig::Computation::GetProgramShape; +%unignore xla::swig::Computation::GetReturnValueShape; +%unignore xla::swig::Computation::GetSerializedProto; %unignore xla::swig::LocalOp; -%unignore xla::swig::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::Build; -%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; -%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; -%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; -%unignore xla::swig::LocalComputationBuilder::Parameter; -%unignore xla::swig::LocalComputationBuilder::GetShape; -%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; -%unignore xla::swig::LocalComputationBuilder::Infeed; -%unignore xla::swig::LocalComputationBuilder::Outfeed; -%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; -%unignore xla::swig::LocalComputationBuilder::ConstantR0; -%unignore xla::swig::LocalComputationBuilder::Iota; -%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; -%unignore xla::swig::LocalComputationBuilder::Broadcast; -%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; -%unignore xla::swig::LocalComputationBuilder::Pad; -%unignore xla::swig::LocalComputationBuilder::Reshape; -%unignore xla::swig::LocalComputationBuilder::Collapse; -%unignore xla::swig::LocalComputationBuilder::AllToAll; -%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; -%unignore xla::swig::LocalComputationBuilder::Slice; -%unignore xla::swig::LocalComputationBuilder::SliceInDim; -%unignore xla::swig::LocalComputationBuilder::DynamicSlice; -%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::LocalComputationBuilder::ConcatInDim; -%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::Select; -%unignore xla::swig::LocalComputationBuilder::Tuple; -%unignore 
xla::swig::LocalComputationBuilder::GetTupleElement; -%unignore xla::swig::LocalComputationBuilder::ConvertElementType; -%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; -%unignore xla::swig::LocalComputationBuilder::Call; -%unignore xla::swig::LocalComputationBuilder::Transpose; -%unignore xla::swig::LocalComputationBuilder::Rev; -%unignore xla::swig::LocalComputationBuilder::Clamp; -%unignore xla::swig::LocalComputationBuilder::Map; -%unignore xla::swig::LocalComputationBuilder::Reduce; -%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::RngNormal; -%unignore xla::swig::LocalComputationBuilder::RngUniform; -%unignore xla::swig::LocalComputationBuilder::RngBernoulli; -%unignore xla::swig::LocalComputationBuilder::While; -%unignore xla::swig::LocalComputationBuilder::Conditional; -%unignore xla::swig::LocalComputationBuilder::IsConstant; -%unignore xla::swig::LocalComputationBuilder::Eq; -%unignore xla::swig::LocalComputationBuilder::Ne; -%unignore xla::swig::LocalComputationBuilder::Ge; -%unignore xla::swig::LocalComputationBuilder::Gt; -%unignore xla::swig::LocalComputationBuilder::Lt; -%unignore xla::swig::LocalComputationBuilder::Le; -%unignore xla::swig::LocalComputationBuilder::Dot; -%unignore xla::swig::LocalComputationBuilder::DotGeneral; -%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::LocalComputationBuilder::Add; -%unignore xla::swig::LocalComputationBuilder::Sub; -%unignore xla::swig::LocalComputationBuilder::Mul; -%unignore xla::swig::LocalComputationBuilder::Div; -%unignore xla::swig::LocalComputationBuilder::Rem; -%unignore xla::swig::LocalComputationBuilder::Max; -%unignore xla::swig::LocalComputationBuilder::Min; -%unignore xla::swig::LocalComputationBuilder::And; -%unignore xla::swig::LocalComputationBuilder::Or; -%unignore xla::swig::LocalComputationBuilder::Xor; -%unignore xla::swig::LocalComputationBuilder::ShiftLeft; -%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; -%unignore xla::swig::LocalComputationBuilder::Not; -%unignore xla::swig::LocalComputationBuilder::Abs; -%unignore xla::swig::LocalComputationBuilder::Exp; -%unignore xla::swig::LocalComputationBuilder::Expm1; -%unignore xla::swig::LocalComputationBuilder::Floor; -%unignore xla::swig::LocalComputationBuilder::Ceil; -%unignore xla::swig::LocalComputationBuilder::Round; -%unignore xla::swig::LocalComputationBuilder::Log; -%unignore xla::swig::LocalComputationBuilder::Log1p; -%unignore xla::swig::LocalComputationBuilder::Sign; -%unignore xla::swig::LocalComputationBuilder::Cos; -%unignore xla::swig::LocalComputationBuilder::Sin; -%unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::Atan2; -%unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::Pow; -%unignore xla::swig::LocalComputationBuilder::Neg; -%unignore xla::swig::LocalComputationBuilder::Sort; -%unignore xla::swig::LocalComputationBuilder::SortKeyVal; -%unignore xla::swig::LocalComputationBuilder::Sqrt; -%unignore xla::swig::LocalComputationBuilder::Rsqrt; -%unignore xla::swig::LocalComputationBuilder::Square; -%unignore xla::swig::LocalComputationBuilder::Reciprocal; -%unignore xla::swig::LocalComputationBuilder::Erfc; -%unignore xla::swig::LocalComputationBuilder::Erf; -%unignore xla::swig::LocalComputationBuilder::ErfInv; -%unignore 
xla::swig::LocalComputationBuilder::Lgamma; -%unignore xla::swig::LocalComputationBuilder::Digamma; -%unignore xla::swig::LocalComputationBuilder::Acos; -%unignore xla::swig::LocalComputationBuilder::Asin; -%unignore xla::swig::LocalComputationBuilder::Atan; -%unignore xla::swig::LocalComputationBuilder::Tan; -%unignore xla::swig::LocalComputationBuilder::Acosh; -%unignore xla::swig::LocalComputationBuilder::Asinh; -%unignore xla::swig::LocalComputationBuilder::Atanh; -%unignore xla::swig::LocalComputationBuilder::Cosh; -%unignore xla::swig::LocalComputationBuilder::Sinh; -%unignore xla::swig::LocalComputationBuilder::Real; -%unignore xla::swig::LocalComputationBuilder::Imag; -%unignore xla::swig::LocalComputationBuilder::Conj; -%unignore xla::swig::LocalComputationBuilder::Complex; -%unignore xla::swig::LocalComputationBuilder::Cholesky; -%unignore xla::swig::LocalComputationBuilder::QR; -%unignore xla::swig::LocalComputationBuilder::TriangularSolve; -%unignore xla::swig::LocalComputationBuilder::CustomCall; -%unignore xla::swig::LocalComputationBuilder::Gather; -%unignore xla::swig::LocalComputationBuilder::Scatter; -%unignore xla::swig::DeleteLocalComputation; -%unignore xla::swig::DestructureLocalShapedBufferTuple; +%unignore xla::swig::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::Build; +%unignore xla::swig::ComputationBuilder::BuildWithRoot; +%unignore xla::swig::ComputationBuilder::SetOpMetadata; +%unignore xla::swig::ComputationBuilder::ClearOpMetadata; +%unignore xla::swig::ComputationBuilder::Parameter; +%unignore xla::swig::ComputationBuilder::GetShape; +%unignore xla::swig::ComputationBuilder::GetReturnValueShape; +%unignore xla::swig::ComputationBuilder::Infeed; +%unignore xla::swig::ComputationBuilder::Outfeed; +%unignore xla::swig::ComputationBuilder::ConstantLiteral; +%unignore xla::swig::ComputationBuilder::ConstantR0; +%unignore xla::swig::ComputationBuilder::Iota; +%unignore xla::swig::ComputationBuilder::BroadcastedIota; +%unignore xla::swig::ComputationBuilder::Broadcast; +%unignore xla::swig::ComputationBuilder::BroadcastInDim; +%unignore xla::swig::ComputationBuilder::Pad; +%unignore xla::swig::ComputationBuilder::Reshape; +%unignore xla::swig::ComputationBuilder::Collapse; +%unignore xla::swig::ComputationBuilder::AllToAll; +%unignore xla::swig::ComputationBuilder::CrossReplicaSum; +%unignore xla::swig::ComputationBuilder::Slice; +%unignore xla::swig::ComputationBuilder::SliceInDim; +%unignore xla::swig::ComputationBuilder::DynamicSlice; +%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::ComputationBuilder::ConcatInDim; +%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::Select; +%unignore xla::swig::ComputationBuilder::Tuple; +%unignore xla::swig::ComputationBuilder::GetTupleElement; +%unignore xla::swig::ComputationBuilder::ConvertElementType; +%unignore xla::swig::ComputationBuilder::BitcastConvertType; +%unignore xla::swig::ComputationBuilder::Call; +%unignore xla::swig::ComputationBuilder::Transpose; +%unignore xla::swig::ComputationBuilder::Rev; +%unignore xla::swig::ComputationBuilder::Clamp; +%unignore xla::swig::ComputationBuilder::Map; +%unignore xla::swig::ComputationBuilder::Reduce; +%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::RngNormal; +%unignore xla::swig::ComputationBuilder::RngUniform; +%unignore 
xla::swig::ComputationBuilder::RngBernoulli; +%unignore xla::swig::ComputationBuilder::While; +%unignore xla::swig::ComputationBuilder::Conditional; +%unignore xla::swig::ComputationBuilder::IsConstant; +%unignore xla::swig::ComputationBuilder::Eq; +%unignore xla::swig::ComputationBuilder::Ne; +%unignore xla::swig::ComputationBuilder::Ge; +%unignore xla::swig::ComputationBuilder::Gt; +%unignore xla::swig::ComputationBuilder::Lt; +%unignore xla::swig::ComputationBuilder::Le; +%unignore xla::swig::ComputationBuilder::Dot; +%unignore xla::swig::ComputationBuilder::DotGeneral; +%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::ComputationBuilder::Add; +%unignore xla::swig::ComputationBuilder::Sub; +%unignore xla::swig::ComputationBuilder::Mul; +%unignore xla::swig::ComputationBuilder::Div; +%unignore xla::swig::ComputationBuilder::Rem; +%unignore xla::swig::ComputationBuilder::Max; +%unignore xla::swig::ComputationBuilder::Min; +%unignore xla::swig::ComputationBuilder::And; +%unignore xla::swig::ComputationBuilder::Or; +%unignore xla::swig::ComputationBuilder::Xor; +%unignore xla::swig::ComputationBuilder::ShiftLeft; +%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::ComputationBuilder::ShiftRightLogical; +%unignore xla::swig::ComputationBuilder::Not; +%unignore xla::swig::ComputationBuilder::Abs; +%unignore xla::swig::ComputationBuilder::Exp; +%unignore xla::swig::ComputationBuilder::Expm1; +%unignore xla::swig::ComputationBuilder::Floor; +%unignore xla::swig::ComputationBuilder::Ceil; +%unignore xla::swig::ComputationBuilder::Round; +%unignore xla::swig::ComputationBuilder::Log; +%unignore xla::swig::ComputationBuilder::Log1p; +%unignore xla::swig::ComputationBuilder::Sign; +%unignore xla::swig::ComputationBuilder::Cos; +%unignore xla::swig::ComputationBuilder::Sin; +%unignore xla::swig::ComputationBuilder::Tanh; +%unignore xla::swig::ComputationBuilder::Atan2; +%unignore xla::swig::ComputationBuilder::IsFinite; +%unignore xla::swig::ComputationBuilder::Pow; +%unignore xla::swig::ComputationBuilder::Neg; +%unignore xla::swig::ComputationBuilder::Sort; +%unignore xla::swig::ComputationBuilder::SortKeyVal; +%unignore xla::swig::ComputationBuilder::Sqrt; +%unignore xla::swig::ComputationBuilder::Rsqrt; +%unignore xla::swig::ComputationBuilder::Square; +%unignore xla::swig::ComputationBuilder::Reciprocal; +%unignore xla::swig::ComputationBuilder::Erfc; +%unignore xla::swig::ComputationBuilder::Erf; +%unignore xla::swig::ComputationBuilder::ErfInv; +%unignore xla::swig::ComputationBuilder::Lgamma; +%unignore xla::swig::ComputationBuilder::Digamma; +%unignore xla::swig::ComputationBuilder::Acos; +%unignore xla::swig::ComputationBuilder::Asin; +%unignore xla::swig::ComputationBuilder::Atan; +%unignore xla::swig::ComputationBuilder::Tan; +%unignore xla::swig::ComputationBuilder::Acosh; +%unignore xla::swig::ComputationBuilder::Asinh; +%unignore xla::swig::ComputationBuilder::Atanh; +%unignore xla::swig::ComputationBuilder::Cosh; +%unignore xla::swig::ComputationBuilder::Sinh; +%unignore xla::swig::ComputationBuilder::Real; +%unignore xla::swig::ComputationBuilder::Imag; +%unignore xla::swig::ComputationBuilder::Conj; +%unignore xla::swig::ComputationBuilder::Complex; +%unignore xla::swig::ComputationBuilder::Cholesky; +%unignore xla::swig::ComputationBuilder::QR; +%unignore xla::swig::ComputationBuilder::TriangularSolve; +%unignore xla::swig::ComputationBuilder::CustomCall; +%unignore xla::swig::ComputationBuilder::Gather; 
+%unignore xla::swig::ComputationBuilder::Scatter; +%unignore xla::swig::DeleteComputation; %unignore xla::swig::DestructureXrtAllocationTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteCompiledLocalComputation; -%unignore xla::swig::DeleteCompiledXrtComputation; +%unignore xla::swig::DeleteLocalExecutable; +%unignore xla::swig::DeleteXrtExecutable; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 8e056f97255..aa692c78655 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -127,28 +127,42 @@ bool NumpyTypeIsValid(int np_type) { } } -PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) { int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); - PyObject* dimensions; + Safe_PyObjectPtr dimensions; if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); - dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape))); for (int i = 0; i < num_elements; ++i) { PyTuple_SET_ITEM( - dimensions, i, - PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + dimensions.get(), i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)) + .release()); } } else { int rank = shape.rank(); - dimensions = PyTuple_New(rank); + dimensions = make_safe(PyTuple_New(rank)); for (int i = 0; i < rank; ++i) { - PyTuple_SET_ITEM(dimensions, i, + PyTuple_SET_ITEM(dimensions.get(), i, LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); } } - return PyTuple_Pack(2, np_dtype, dimensions); + return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release())); +} + +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape) { + Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size())); + for (int i = 0; i < shape.parameters_size(); ++i) { + PyTuple_SET_ITEM(arg_shapes.get(), i, + PyShapeInfoFromXlaShape(shape.parameters(i)).release()); + } + + Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result()); + return make_safe( + PyTuple_Pack(2, arg_shapes.release(), result_shape.release())); } // Precondition: o->ob_type == &PyArrayDescr_Type diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 737fc4b29c1..89861fc4f01 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -64,7 +64,13 @@ bool NumpyTypeIsValid(int np_type); // providing the array dimensions. // // The return value is a new reference. -PyObject* PyShapeInfoFromXlaShape(const Shape& shape); +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple +// of argument shapes and result_shape is the result shape. Each shape is as +// described in in PyShapeInfoFromXlaShape's comment. +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape); // Converts a Python object with a method interface mathing that of // xla_client.Shape into an XLA Shape object. 
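The new ProgramShape typemap hands Python a plain ([arg_shape_info, ...], result_shape_info) pair, where each shape info is the (dtype, dimensions) pair described in the conversion table earlier in the .i file. The helper below is illustrative only, assuming nothing beyond numpy; the real code path is Computation.GetProgramShape in xla_client.py, which wraps the entries in xla_client.Shape instead.

import collections

import numpy as np

# Illustrative only: one way the (arg_shapes, result_shape) pair can be turned
# into a named structure on the Python side.
ProgramShapeInfo = collections.namedtuple(  # hypothetical name
    'ProgramShapeInfo', ('parameter_shapes', 'result_shape'))

def wrap_program_shape_info(pair):
  """pair is ((arg_shape_info, ...), result_shape_info); each shape info is a
  (numpy dtype, dimensions tuple) pair as produced by the shape typemaps."""
  arg_infos, result_info = pair
  return ProgramShapeInfo(tuple(arg_infos), result_info)

# Example: a program taking two f32[2,3] parameters and returning f32[3].
info = wrap_program_shape_info(
    (((np.dtype('float32'), (2, 3)), (np.dtype('float32'), (2, 3))),
     (np.dtype('float32'), (3,))))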
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c7afed300b6..020cc587fe7 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -36,7 +36,7 @@ from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and LocalComputation are CamelCase for +# method names of ComputationBuilder and Computation are CamelCase for # consistency with XLA. # pylint: disable=invalid-name @@ -50,7 +50,7 @@ from tensorflow.compiler.xla.service import hlo_pb2 # which case we need to be able to detect when incompatible versions are # installed. def version(): - return (0, 1, 7) + return (0, 1, 8) _OP_METADATA_FIELDS = [ @@ -66,6 +66,10 @@ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) class Backend(object): """Abstract base class for XLA backends.""" + @abc.abstractmethod + def device_count(self): + """Returns the number of devices known to the backend.""" + @abc.abstractmethod def buffer_from_pyval(self, pyval, device=0): """Allocates a fresh buffer and populates it with `pyval`.""" @@ -95,25 +99,39 @@ class Backend(object): """Runs an executable in a replicated manner.""" +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + class XlaLocalBackend(Backend): """XLA backend implemented using the in-process xla::LocalClient API.""" + def __init__(self, platform=None): + platform = platform or _get_default_platform_name() + self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) + + def device_count(self): + return self.client.DeviceCount() + def buffer_from_pyval(self, pyval, device=0): - return c_api.LocalShapedBuffer.FromLiteral(pyval, None, device) + return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) def delete_buffer(self, c_buffer): c_api.DeleteLocalShapedBuffer(c_buffer) def destructure_tuple(self, c_buffer): - result = c_api.DestructureLocalShapedBufferTuple(c_buffer) + result = c_buffer.DestructureTuple() return [result.Release(i) for i in xrange(result.size())] def compile(self, c_computation, argument_shapes, compile_options): - return c_computation.Compile(argument_shapes, compile_options) + return c_computation.Compile(argument_shapes, compile_options, self.client) def delete_executable(self, executable): - assert isinstance(executable, c_api.CompiledLocalComputation) - c_api.DeleteCompiledLocalComputation(executable) + assert isinstance(executable, c_api.LocalExecutable) + c_api.DeleteLocalExecutable(executable) def execute(self, executable, args): return executable.Execute(args) @@ -130,6 +148,9 @@ class XrtBackend(Backend): def __init__(self, target): self.target = target + def device_count(self): + return 1 # Multidevice execution not implemented. 
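A hedged usage sketch of the backend plumbing above, assuming a TensorFlow build whose xla_client module exposes the in-process 'Host' (CPU) platform; the LocalBuffer.from_pyval call relies on the device/backend keywords introduced further down in this patch.

import numpy as np

from tensorflow.compiler.xla.python import xla_client

# One backend per platform; the constructor resolves the platform name via
# c_api.LocalClient.Get.
backend = xla_client.XlaLocalBackend(platform='Host')
print('visible devices:', backend.device_count())

# Device ordinals replace the old replica numbers when placing data.
buf = xla_client.LocalBuffer.from_pyval(
    np.arange(6, dtype=np.float32).reshape(2, 3), device=0, backend=backend)
print(buf.device(), buf.shape())
buf.delete()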
+ def buffer_from_pyval(self, pyval, device=0): if device != 0: raise NotImplementedError( @@ -150,8 +171,8 @@ class XrtBackend(Backend): _maybe_encode_string(self.target)) def delete_executable(self, executable): - assert isinstance(executable, c_api.CompiledXrtComputation) - c_api.DeleteCompiledXrtComputation(executable) + assert isinstance(executable, c_api.XrtExecutable) + c_api.DeleteXrtExecutable(executable) def execute(self, executable, args): return executable.Execute(args) @@ -163,7 +184,20 @@ class XrtBackend(Backend): return [executable.Execute(per_replica_args[0])] -XLA_LOCAL_BACKEND = XlaLocalBackend() +_default_platform_name = 'Host' +_default_backend = None + + +def _get_default_platform_name(): + return _default_platform_name + + +def _get_default_local_backend(): + global _default_backend + global _default_platform_name + if _default_backend is None: + _default_backend = XlaLocalBackend(_default_platform_name) + return _default_backend class BackendType(enum.Enum): @@ -174,7 +208,7 @@ class BackendType(enum.Enum): def BackendSpec(backend, target): """Compatibility wrapper to support older clients. Do not use in new code.""" if backend == BackendType.XLA_LOCAL: - return XLA_LOCAL_BACKEND + return _get_default_local_backend() elif backend == BackendType.XRT: return XrtBackend(target) else: @@ -201,13 +235,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -346,22 +373,18 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend, replica): + def __init__(self, c_buffer, backend, device): self.c_buffer = c_buffer self._backend = backend - self._replica = replica + self._device = device @staticmethod - def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, device=0, backend=None): """Allocate and copy to XLA the given python value.""" + backend = backend or _get_default_local_backend() pyval = require_numpy_array_layout(pyval) - num_replicas = get_replica_count() - if not 0 <= replica < num_replicas: - raise ValueError( - 'Attempt to place buffer on replica {} when the replica count is {}' - .format(replica, num_replicas)) - cbuf = backend.buffer_from_pyval(pyval, replica) - return LocalBuffer(cbuf, backend, replica) + cbuf = backend.buffer_from_pyval(pyval, device) + return LocalBuffer(cbuf, backend, device) def to_py(self): return self.c_buffer.ToLiteral() @@ -369,8 +392,8 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) - def replica(self): - return self._replica + def device(self): + return self._device def delete(self): if self.c_buffer is not None: @@ -383,7 +406,7 @@ class LocalBuffer(object): result = self._backend.destructure_tuple(self.c_buffer) self.delete() return tuple( - LocalBuffer(sub_buffer, replica=self._replica, backend=self._backend) + LocalBuffer(sub_buffer, device=self._device, backend=self._backend) for sub_buffer in result) def is_deleted(self): @@ -533,6 +556,16 @@ class Shape(object): updated._check_minor_to_major() # pylint: disable=protected-access return updated + def with_major_to_minor_layout_if_absent(self): + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def f(a): + if a.minor_to_major(): + return None + return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) + + return 
self.map_leaves(f) + def serialize(self, proto): """Serializes 'shape' into proto.""" if self.is_tuple(): @@ -548,6 +581,10 @@ class Shape(object): proto.layout.minor_to_major.extend(self.minor_to_major()) +ProgramShape = collections.namedtuple('ProgramShape', + ('parameter_shapes', 'result_shape')) + + def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] @@ -581,7 +618,7 @@ class CompileOptions(object): self.num_replicas = get_replica_count() -def transfer_to_infeed(value, replica_number=None): +def transfer_to_infeed(value, device_ordinal=0): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -591,52 +628,49 @@ def transfer_to_infeed(value, replica_number=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - replica_number: the replica number to infeed the value to -- if not - provided, then the default replica (trivially replica 0) is used. + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. """ - if replica_number is None: - c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) - else: - c_api.TransferToInfeedLocalReplica( - require_numpy_array_layout(value), replica_number) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + backend.client.TransferToInfeed( + require_numpy_array_layout(value), device_ordinal) -def transfer_from_outfeed(shape, replica_number=None): - """Transfers a literal of the given shape from replica_number's outfeed. +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - replica_number: The replica number ordinal to transfer the outfeed value - from. (Each replica has a distinct outfeed queue.) + device_ordinal: The device ordinal to transfer the outfeed value from. Each + device has a distinct outfeed queue.. Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + return backend.client.TransferFromOutfeed(shape, device_ordinal) -class LocalComputation(object): - """Python wrapper for a local XLA Computation. +class Computation(object): + """Python wrapper for an XLA Computation. - A LocalComputation can be executed if it is compiled. Otherwise, it - can still be used as a Computation where required by the - ComputationBuilder methods. + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + def __init__(self, c_computation, backend=None): self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. self._backend = backend - self._is_compiled = is_compiled @property def computation(self): - if self._is_compiled: - raise ValueError( - 'Attempt to read the XLA computation of a compiled LocalComputation.') return self._c_computation def GetProto(self): - """Get the HloModuleProto proto object in this local computation. + """Get the HloModuleProto proto object in this computation. Returns: An HloModuleProto proto object that has the whole-graph information. 
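The default layout filled in by with_major_to_minor_layout_if_absent above is the row-major one: the last dimension is most minor. A tiny illustrative sketch (not part of the patch) of the tuple it builds:

def _default_minor_to_major(rank):
  # Matches tuple(xrange(rank - 1, -1, -1)) above: for a rank-3 array this is
  # (2, 1, 0), i.e. dimension 2 varies fastest in memory.
  return tuple(range(rank - 1, -1, -1))

assert _default_minor_to_major(3) == (2, 1, 0)
assert _default_minor_to_major(1) == (0,)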
@@ -645,30 +679,25 @@ class LocalComputation(object): proto = hlo_pb2.HloModuleProto.FromString(serialized) return proto - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): - """Compiles an un-compiled local computation. + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, + backend=None): + """Compiles a computation. - Local computations are the result of a "LocalComputationBuild'ing" process - -- they start in uncompiled form, and via a call to Compile() turn into a - compiled local computation. - - Raises: - ValueError: if this is already a compiled local computation. + Computations are the result of a "ComputationBuild'ing" process. Arguments: argument_shapes: parameter shapes -- they are first laid out by layout_fn if layout_fn is provided. Otherwise, the default layout for those shapes will be used. - compile_options: options to use for compilation, includes an optional - laid out result shape for the computation. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. layout_fn: lambda that is used to lay out the argument/result shapes. + backend: a `Backend` for which an executable should be generated. Returns: - A newly *compiled* local computation instance. + A Executable instance. """ - if self._is_compiled: - raise ValueError('Attempt to compile a compiled local XLA computation.') - + backend = backend or self._backend or _get_default_local_backend() result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: @@ -681,29 +710,55 @@ class LocalComputation(object): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - c = self._backend.compile(self.computation, argument_shapes, - compile_options) - return LocalComputation(c, is_compiled=True, backend=self._backend) + c = backend.compile(self.computation, argument_shapes, compile_options) + return Executable(c, backend=backend) def CompileWithExampleArguments(self, arguments=(), compile_options=None, - layout_fn=None): + layout_fn=None, + backend=None): return self.Compile( argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, - layout_fn=layout_fn) + layout_fn=layout_fn, + backend=backend) + + def GetProgramShape(self): + (arg_shapes, result_shape) = self._c_computation.GetProgramShape() + return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], + _wrap_shape(result_shape)) def GetReturnValueShape(self): return _wrap_shape(self._c_computation.GetReturnValueShape()) + def __del__(self): + # Python may have freed c_api first. 
+ if c_api and self._c_computation: + assert isinstance(self._c_computation, c_api.Computation) + c_api.DeleteComputation(self._c_computation) + + +class Executable(object): + """Python wrapper for an XLA Executable.""" + + def __init__(self, c_executable, backend=None): + self._c_executable = c_executable + self._device_ordinals = c_executable.DeviceOrdinals() + self._backend = backend + + def DeviceOrdinals(self): + """Returns a list containing the device ordinals for each replica.""" + return self._device_ordinals + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): raise ValueError('Executing with deleted local buffer argument') raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._backend.execute(self._c_computation, raw_args) - return LocalBuffer(output_buffer, backend=self._backend, replica=0) + output_buffer = self._backend.execute(self._c_executable, raw_args) + return LocalBuffer( + output_buffer, backend=self._backend, device=self._device_ordinals[0]) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -713,14 +768,12 @@ class LocalComputation(object): sequence comprises the arguments for execution on the i'th replica. Returns: - A list of the computation's outputs on each replica, as a LocalBuffer. If + A list of the computation's outputs for each replica, as a LocalBuffer. If a shallow sequence of arguments was passed in for `arguments`, then the sole, zero'th replica's output is returned instead, as a LocalBuffer. """ - if not self._is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') if arguments is None: - arguments = ((),) * get_replica_count() + arguments = ((),) * len(self._device_ordinals) else: arguments = [list(replica_args) for replica_args in arguments] @@ -729,30 +782,35 @@ class LocalComputation(object): for arg in replica_args: if arg.is_deleted(): raise ValueError('Executing with deleted local buffer argument') - if arg.replica() != replica: + if arg.device() != self._device_ordinals[replica]: raise ValueError( - 'Executing on replica {} with argument from replica {}'.format( - replica, arg.replica())) + 'Executing on device {} with argument from device {}'.format( + self._device_ordinals[replica], arg.device())) # Pull out argument buffer handles + # pylint: disable=g-complex-comprehension stripped_args = [ [arg.c_buffer for arg in replica_args] for replica_args in arguments ] # Execute - output_buffers = self._backend.execute_replicated( - self._c_computation, stripped_args) + output_buffers = self._backend.execute_replicated(self._c_executable, + stripped_args) # Wrap output handles in LocalBuffer instances return tuple( - LocalBuffer(output_buffer, backend=self._backend, replica=replica) + LocalBuffer( + output_buffer, + backend=self._backend, + device=self._device_ordinals[replica]) for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): """Execute on one replica with Python values as arguments and output.""" def put(arg): - return LocalBuffer.from_pyval(arg, backend=self._backend) + return LocalBuffer.from_pyval( + arg, device=self._device_ordinals[0], backend=self._backend) arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() @@ -760,22 +818,19 @@ class LocalComputation(object): def 
ExecuteWithPythonValuesPerReplica(self, arguments): """Execute on many replicas with Python values as arguments and output.""" - def put(arg, replica): - return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + def put(arg, device): + return LocalBuffer.from_pyval(arg, device, backend=self._backend) - arguments = [[put(arg, replica) - for arg in replica_args] - for replica, replica_args in enumerate(arguments)] + # pylint: disable=g-complex-comprehension + arguments = [[ + put(arg, self._device_ordinals[replica]) for arg in replica_args + ] for replica, replica_args in enumerate(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): # Python may have freed c_api first. - if c_api and self._c_computation: - if self._is_compiled: - self._backend.delete_executable(self._c_computation) - else: - assert isinstance(self._c_computation, c_api.LocalComputation) - c_api.DeleteLocalComputation(self._c_computation) + if c_api and self._c_executable: + self._backend.delete_executable(self._c_executable) def _make_replica_group_proto(replica_group): @@ -788,8 +843,8 @@ class ComputationBuilder(object): """XLA computation builder. Enqueues XLA ops in sequence and in order to build a - LocalComputation, which in turn can be compiled into a - CompiledLocalComputation, which in turn can be locally executed. + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. """ # The methods of this class map 1-to-1 onto the XLA C++ @@ -800,16 +855,23 @@ class ComputationBuilder(object): # pylint: disable=g-doc-args def __init__(self, name): - self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._client = c_api.ComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + backend: deprecated. Pass a `backend` to `Computation.Compile` instead. + Returns: + A `Computation`. + """ if root is not None: - return LocalComputation( - self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + return Computation(self._client.BuildWithRoot(root), backend=backend) else: - return LocalComputation( - self._client.Build(), is_compiled=False, backend=backend) + return Computation(self._client.Build(), backend=backend) def SetOpMetadata(self, op_metadata): """Set metadata for operations that are about to be enqueued.""" @@ -1461,7 +1523,7 @@ class ComputationBuilder(object): Args: operand: a LocalOp to test. - Returns: a LocalComputation that is rooted on the given `operand` which is a + Returns: a Computation that is rooted on the given `operand` which is a compile-time constant. """ return self._client.BuildConstantSubGraph(operand) @@ -1662,7 +1724,7 @@ def _forward_methods_to_local_builder(): Set up methods, corresponding to unary and binary XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - LocalComputationBuilder C-extension API. + ComputationBuilder C-extension API. 
""" def forward_to_local_builder_with_handles(target_method, is_binop=False): @@ -1682,13 +1744,13 @@ def _forward_methods_to_local_builder(): for method_name in _UNARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name)) + getattr(c_api.ComputationBuilder, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) for method_name in _BINARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + getattr(c_api.ComputationBuilder, method_name), is_binop=True) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) @@ -1696,8 +1758,14 @@ def _forward_methods_to_local_builder(): _forward_methods_to_local_builder() +_default_replica_count = 1 + + def initialize_replica_count(replica_count): - """Initializes the desired replica count to use on XLA service init. + """Initializes the default replica count to use. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. Args: replica_count: number of replicas that are desired for set up during XLA @@ -1706,31 +1774,30 @@ def initialize_replica_count(replica_count): Raises: A runtime exception if the XLA service has already been initialized. """ - c_api.InitializeReplicaCount(replica_count) - - -def initialize_platform_name(platform_name): - """Initializes the desired platform name to use on XLA service init. - - Args: - platform_name: string name of platform. - - Raises: - A runtime exception if the XLA service has already been initialized. - A runtime exception if the platform does not exist, or there are no devices - with that platform. - """ - platform_name = _maybe_encode_string(platform_name) - c_api.InitializePlatformName(platform_name) + global _default_replica_count + _default_replica_count = replica_count def get_replica_count(): - """Returns the current replica count used for the XLA service. + """Returns the default replica count. - Note: this will return a value whether the XLA service has been initialized - yet or not. + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. """ - return c_api.GetReplicaCount() + return _default_replica_count + + +def initialize_platform_name(platform_name): + """Initializes the default platform name to use for XLA. + + Args: + platform_name: string name of platform. + """ + global _default_platform_name + _default_platform_name = platform_name + + # Make sure the platform is valid by trying to instantiate it. 
+ _get_default_local_backend() def register_cpu_custom_call_target(name, fn): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index aa38c06cf90..f830cb26e3d 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -29,7 +29,7 @@ from tensorflow.compiler.xla.python import xla_client import unittest -class LocalComputationTest(unittest.TestCase): +class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -85,7 +85,7 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationsWithConstantsTest(LocalComputationTest): +class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" def testConstantScalarSumS8(self): @@ -304,7 +304,7 @@ class ComputationsWithConstantsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=0.75) -class ParametersTest(LocalComputationTest): +class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" def setUp(self): @@ -384,7 +384,7 @@ class ParametersTest(LocalComputationTest): expected=[-4.3, 1.3, -6.3, 3.3]) -class LocalBufferTest(LocalComputationTest): +class LocalBufferTest(ComputationTest): """Tests focusing on execution with LocalBuffers.""" def _Execute(self, c, arguments): @@ -482,7 +482,7 @@ class LocalBufferTest(LocalComputationTest): self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) -class SingleOpTest(LocalComputationTest): +class SingleOpTest(ComputationTest): """Tests for single ops. The goal here is smoke testing - to exercise the most basic functionality of @@ -1175,7 +1175,7 @@ class SingleOpTest(LocalComputationTest): np.testing.assert_allclose(g, expected, rtol=1e-4) -class EmbeddedComputationsTest(LocalComputationTest): +class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" def _CreateConstantS32Computation(self): @@ -1639,7 +1639,7 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=expected) -class ErrorTest(LocalComputationTest): +class ErrorTest(ComputationTest): def setUp(self): self.f32_scalar_2 = NumpyArrayF32(2.0) @@ -1656,7 +1656,7 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) -class ComputationRootTest(LocalComputationTest): +class ComputationRootTest(ComputationTest): """Tests related to setting the root of the computation.""" def testComputationRootDifferentFromLastOp(self): diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a5eae6d3962..33ac51ca4bb 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3529,6 +3529,37 @@ tf_cc_test( ], ) +cc_library( + name = "stable_sort_expander", + srcs = ["stable_sort_expander.cc"], + hdrs = ["stable_sort_expander.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "stable_sort_expander_test", + srcs = ["stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + 
":pattern_matcher_gmock", + ":stable_sort_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c5deb74e96a..9b037960cda 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -280,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { hlo)); } - // Helper method to perform and add reduction in a single dimension. - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. + HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the msot major, + // and the contracting dimensions are most minor. + StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions) { + std::vector transpose_dimensions(batch_dimensions.begin(), + batch_dimensions.end()); + for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { + if (!(absl::c_linear_search(batch_dimensions, i) || + absl::c_linear_search(contracting_dimensions, i))) { + transpose_dimensions.push_back(i); + } + } + transpose_dimensions.insert(transpose_dimensions.end(), + contracting_dimensions.begin(), + contracting_dimensions.end()); + return MakeTransposeHlo(dot_operand, transpose_dimensions); + } + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); - Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + Shape shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, + hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( - shape, hlo, zero, {dim}, AddReduce_computation)); + shape, hlo, zero, dims, AddReduce_computation)); + } + + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + return AddReduce(hlo, std::vector{dim}); } // Convenience method for replacing an instruction with a bitcast. 
If operand @@ -1120,16 +1156,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( std::swap(rhs_collapsing_dim, rhs_kept_dim); } - auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { - if (hlo->shape().element_type() == element_type) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); - }; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = as_type(hlo, dot->shape().element_type()); + hlo = AsType(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); @@ -1138,7 +1166,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(as_type(hlo, F32), dim); + return AddReduce(AsType(hlo, F32), dim); }; auto broadcast = [&](HloInstruction* hlo, const Shape& shape, @@ -1247,8 +1275,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return dims; }; - // If the contracting dimension is 1, remove the degnerate dimnesions from the - // lhs and rhs, broadcast each to the result shape and multiply. + // If the contracting dimension is 1, remove the degnerate dimnensions from + // the lhs and rhs, broadcast each to the result shape and multiply. if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && (rhs_kept_dim == rhs_rank - 1 || (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { @@ -1608,34 +1636,26 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If there are no contracting dimensions, a dot can be rewritten as // mul(broadcast(transpose(x)),broadcast(transpose(y))) if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { - std::vector lhs_transpose( - dot->dot_dimension_numbers().lhs_batch_dimensions().begin(), - dot->dot_dimension_numbers().lhs_batch_dimensions().end()); - for (int64 i = 0; i < lhs->shape().rank(); ++i) { - if (!absl::c_linear_search( - dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) { - lhs_transpose.push_back(i); - } - } - TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, - MakeTransposeHlo(lhs, lhs_transpose)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); if (dot->shape().rank() != lhs->shape().rank()) { std::vector lhs_broadcast_dims(lhs->shape().rank()); absl::c_iota(lhs_broadcast_dims, 0); new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( dot->shape(), new_lhs, lhs_broadcast_dims)); } - std::vector rhs_transpose( - dot->dot_dimension_numbers().rhs_batch_dimensions().begin(), - dot->dot_dimension_numbers().rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs->shape().rank(); ++i) { - if (!absl::c_linear_search( - dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) { - rhs_transpose.push_back(i); - } - } - TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, - MakeTransposeHlo(rhs, rhs_transpose)); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); if 
(dot->shape().rank() != rhs->shape().rank()) { std::vector rhs_broadcast_dims( dot->dot_dimension_numbers().lhs_batch_dimensions_size()); @@ -1651,6 +1671,78 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { new_lhs, new_rhs)); } + // If the lhs or rhs have only batch and contracting dimensions, a dot can be + // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + + int64 lhs_outer_dims = + lhs->shape().rank() - + (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + int64 rhs_outer_dims = + rhs->shape().rank() - + (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); + CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); + if (rhs_outer_dims > 0) { + std::vector lhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(lhs_broadcast_dims, 0); + lhs_broadcast_dims.resize(lhs->shape().rank()); + std::iota(lhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().lhs_batch_dimensions_size(), + lhs_broadcast_dims.end(), + dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + rhs_outer_dims); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_rhs->shape(), new_lhs, lhs_broadcast_dims)); + } else if (lhs_outer_dims > 0) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().rhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + rhs_broadcast_dims.resize(rhs->shape().rank()); + std::iota(rhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size(), + rhs_broadcast_dims.end(), + dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + lhs_outer_dims); + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_lhs->shape(), new_rhs, rhs_broadcast_dims)); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); + std::vector reduce_dims( + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + new_dot = AsType(new_dot, F32); + const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); + absl::c_iota( + reduce_dims, + outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + new_dot = AddReduce(new_dot, reduce_dims); + new_dot = AsType(new_dot, dot->shape().element_type()); + return ReplaceInstruction(dot, new_dot); + } + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc 
b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index feb6a0fb795..d959fafc0c0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2753,8 +2753,9 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - TF_ASSERT_OK( - MakeSortHlo(keys_shape, {keys}, 0, &builder, module.get()).status()); + TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder, + module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2775,7 +2776,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { HloInstruction::CreateParameter(2, values_shape, "values1")); TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( {keys_shape, values_shape, values_shape}), - {keys, values0, values1}, 0, &builder, module.get()) + {keys, values0, values1}, 0, /*is_stable=*/false, + &builder, module.get()) .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); @@ -3712,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); @@ -4220,12 +4222,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { int m, k, n; PrimitiveType element_type; std::tie(m, k, n, element_type) = GetParam(); - - Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); - Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}) - : ShapeUtil::MakeShape(element_type, {1, 3, 5, m}); - Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}) - : ShapeUtil::MakeShape(element_type, {1, 3, 5, n}); + std::vector lhs_dims = {1, 3, 5}; + std::vector rhs_dims = lhs_dims; + std::vector output_dims = lhs_dims; + if (m > 0) { + lhs_dims.push_back(m); + output_dims.push_back(m); + } + if (k > 0) { + lhs_dims.push_back(k); + rhs_dims.push_back(k); + } + if (n > 0) { + rhs_dims.push_back(n); + output_dims.push_back(n); + } + Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction( @@ -4240,7 +4254,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { dot_dnums.add_rhs_batch_dimensions(1); dot_dnums.add_rhs_batch_dimensions(2); if (k > 0) { - dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 
4 : 3); dot_dnums.add_rhs_contracting_dimensions(3); } builder.AddInstruction(HloInstruction::CreateDot( @@ -4248,9 +4262,9 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); - const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1; - const bool computation_should_be_modified = dot_should_be_transformed; - EXPECT_EQ(changed, computation_should_be_modified); + const bool dot_should_be_transformed = + m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; + EXPECT_EQ(changed, dot_should_be_transformed); bool has_no_dot = true; for (const auto& hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kDot) { @@ -4261,10 +4275,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_SUITE_P( - BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2), - ::testing::Values(1, 2), ::testing::Values(F32, BF16))); +INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, + BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(F32, BF16))); class DotStrengthReductionTest : public AlgebraicSimplifierTest, diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 99373dc107a..52d6982c70f 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -32,15 +32,13 @@ limitations under the License. namespace xla { -namespace { - namespace m = match; // Checks if the argument instruction is an AllReduce, followed by a certain // sequence of instructions and then a CRS. It must be possible to move // the AR past each instruction in the sequence. Returns the CRS, which is the // last instruction in the sequence. -absl::optional MatchesArCrsPattern( +absl::optional ArCrsCombiner::MatchesArCrsPattern( HloInstruction* instruction) { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { if (instruction->user_count() != 1) { @@ -77,23 +75,23 @@ absl::optional MatchesArCrsPattern( return absl::nullopt; } auto next = instruction->users()[0]; + int64 distance = 1; while (!next->IsCrossReplicaAllReduce()) { if (can_ar_move_past_instruction(next)) { next = next->users()[0]; } else { return absl::nullopt; } + ++distance; } if (!Cast(next)->IsNoop() && computation_is_addition(next->called_computations()[0])) { - return absl::optional(next); + return absl::optional(ArCrsPair(instruction, next, distance)); } else { return absl::nullopt; } } -} // namespace - absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); @@ -235,15 +233,55 @@ bool ArCrsCombiner::InstructionsComputeSameValue( } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS), + // ... , (ARn, CRS). + // If as we traverse the HLO graph we start tracking the pair (AR2, CRS), + // and later find that AR1's distance from the CRS is longer, we discard + // AR2 and start tracking AR1. 
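The pairing bookkeeping described in the surrounding comment can be modeled with a small Python sketch (a simplified illustration, not the C++ code here): for each CRS keep only the AR that is farthest from it, and remember discarded AR ids so later occurrences of the same id are skipped.

def group_all_reduces(pairs):
  # pairs: (ar_id, crs, distance) tuples in HLO traversal order.
  all_reduce_map = {}   # ar_id -> list of pairs kept for that id
  crs_reserved = {}     # crs -> ar_id currently paired with it
  discarded = set()
  for ar_id, crs, distance in pairs:
    if ar_id in discarded:
      continue
    if crs in crs_reserved:
      prev_id = crs_reserved[crs]
      if all_reduce_map[prev_id][-1][2] < distance:
        # The new AR is farther from the CRS: drop the previously tracked AR.
        del all_reduce_map[prev_id]
        discarded.add(prev_id)
        all_reduce_map.setdefault(ar_id, []).append((ar_id, crs, distance))
        crs_reserved[crs] = ar_id
      else:
        discarded.add(ar_id)
    else:
      all_reduce_map.setdefault(ar_id, []).append((ar_id, crs, distance))
      crs_reserved[crs] = ar_id
  return all_reduce_map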
We put the discarded ids in this set, in order + // to skip processing of short paths when we encounter the other ARs that + // have the same id as AR2. + absl::flat_hash_set discarded_ar_ids; for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto maybe_crs = MatchesArCrsPattern(instruction); - if (maybe_crs) { - auto crs = *maybe_crs; + auto maybe_pair = MatchesArCrsPattern(instruction); + if (maybe_pair) { + auto pair = *maybe_pair; int64 ar_id = *(instruction->all_reduce_id()); - if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) { - all_reduce_map_[ar_id].push_back(instruction); - crs_reserved_map_[crs] = ar_id; + if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) { + continue; + } + auto it = crs_reserved_map_.find(pair.crs); + if (it != crs_reserved_map_.end()) { + auto prev_ar_id = it->second; + // Since there is another AR paired with CRS, + // all_reduce_map_[prev_ar_id] should exist, but + // all_reduce_map_[ar_id] shouldn't. + CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end()); + CHECK_NE(prev_ar_id, ar_id); + auto prev_pair = all_reduce_map_[prev_ar_id].back(); + int64 prev_distance = prev_pair.distance; + if (prev_distance < pair.distance) { + // The current AR's distance to CRS is longer than the previously + // tracked AR, so we discard the previous AR. + all_reduce_map_.erase(prev_ar_id); + discarded_ar_ids.insert(prev_ar_id); + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } else { + // Discard the current AR id because we are keeping the previously + // tracked AR. + discarded_ar_ids.insert(ar_id); + } + } else { + if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) { + int64 prev_distance = all_reduce_map_[ar_id].back().distance; + CHECK_EQ(prev_distance, pair.distance) + << "All ARs with the same AR ID must have the same distance " + "from the corresponding CRSs. Found: " + << prev_distance << " and " << pair.distance; + } + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; } } } @@ -253,11 +291,11 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { auto all_reduce_id = it.first; - auto instruction_vec = it.second; - CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - for (int i = 1; i < instruction_vec.size(); ++i) { - auto instr_i = instruction_vec[i]; + auto pairs_vec = it.second; + CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + auto instr_0 = pairs_vec[0].ar; + for (int i = 1; i < pairs_vec.size(); ++i) { + auto instr_i = pairs_vec[i].ar; auto next_0 = instr_0->users()[0]; auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; @@ -281,8 +319,9 @@ StatusOr ArCrsCombiner::RewriteGraph() { return false; } for (auto it : all_reduce_map_) { - auto instruction_vec = it.second; - for (auto all_reduce : instruction_vec) { + auto pairs_vec = it.second; + for (auto pair : pairs_vec) { + auto all_reduce = pair.ar; auto parent_computation = all_reduce->parent(); auto all_reduce_id = all_reduce->all_reduce_id(); auto prev = all_reduce->mutable_operand(0); @@ -303,16 +342,23 @@ StatusOr ArCrsCombiner::RewriteGraph() { ? next->operands()[1] : next->operands()[0]; // To move the AR past the addition/subtraction, we need to divide - // other_operand by the number of spatial partitions. 
- auto shape = other_operand->shape(); - Literal lit(shape); - lit.PopulateWithValue(num_spatial_partitions_); - auto divisor = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto division = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDivide, other_operand, divisor)); - TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + // other_operand by the number of spatial partitions, except if + // other_operand is a cross-module AR, which can be eliminated. + if (other_operand->IsCrossModuleAllReduce() && + other_operand->user_count() == 1) { + TF_CHECK_OK(other_operand->ReplaceAllUsesWith( + other_operand->mutable_operand(0))); + } else { + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = parent_computation->AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + } break; } default: diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index e61ef5d4f90..f503e1d5f2b 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -26,11 +26,47 @@ limitations under the License. namespace xla { // When the HLO graph contains a cross-module AllReduce, followed by some simple -// linear operations, followed by a cross-replica AllReduce, we can combine the -// CMAR and the CRAR, to use an efficient AllReduce implementation that fully -// utilizes the interconnect bandwidth. +// linear operations, followed by a cross-replica AllReduce (also known as +// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an +// efficient AllReduce implementation that fully utilizes the interconnect +// bandwidth. // Such sequences appear in spatially partitioned models. -// This pass must run right after spatial partitioning. +// This pass must run right after spatial partitioning, when the code is still +// in a single HLO module. +// +// The steps are: +// 1) Find CMARs followed by simple ops followed by CRARs. +// 2) Group CMARs by all_reduce_id. They must all be rewritten. +// 3) Prove that the CMAR patterns in each core produce the same result. +// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the +// other operand by the number of spatial partitions. +// 5) Turn the CRAR into an all-core AllReduce. +// +// The pass also handles the case where multiple CMARs lead to the same CRAR, +// and eliminates all CMARs. This graph: +// +// Y +// | +// X CMAR_2 Z +// | \ / +// CMAR_1 + +// \ / +// + +// | +// CRAR +// +// gets rewritten to: +// +// Z num_partitions +// \ / +// Y div +// \ / +// X + +// \ / +// + +// | +// all-core AR +// class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) @@ -43,6 +79,28 @@ class ArCrsCombiner : public HloModulePass { HloInstruction* i2); private: + // We used this struct because multiple ARs could be paired with the same CRS. + // In this case, we want to select the AR that is furthest from the CRS, + // because it makes it easier to eliminate all ARs during RewriteGraph. + struct ArCrsPair { + HloInstruction* ar; + HloInstruction* crs; + // The length of the path from AR to CRS in the HLO graph. 
+ int64 distance; + + ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, + int64 dist) + : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} + + string ToString() { + return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(), + ", distance: ", distance, ")"); + } + }; + + absl::optional MatchesArCrsPattern( + HloInstruction* instruction); + // If the passed instruction is a while parameter, and the while body is only // called by a single while instruction, return the while instruction. absl::optional WhileFromBodyParameter( @@ -80,8 +138,8 @@ class ArCrsCombiner : public HloModulePass { int num_spatial_partitions_; - // Map from all-reduce ids to the all reduce instructions. - absl::flat_hash_map> all_reduce_map_; + // Map from all-reduce ids to the AR/CRS pairs. + absl::flat_hash_map> all_reduce_map_; // Map from a CRS instruction to the all-reduce ID of the AR paired with the // CRS. Sometimes, several ARs in the code could be paired with the same CRS. diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 5152f0dc884..9c9db74fd2f 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -1005,11 +1005,11 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { op::Tuple(op::AllReduce(op::Add( op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), - op::Divide(op::AllReduce(), op::Constant()))), + op::Parameter())), op::AllReduce(op::Add( op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), - op::Divide(op::AllReduce(), op::Constant()))))); + op::Parameter())))); auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); @@ -1093,15 +1093,17 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::AllReduce(op::Add( - op::Parameter(), - op::Divide(op::Add(op::AllReduce(), op::Constant()), - op::Constant()))), - op::AllReduce(op::Add( - op::Parameter(), - op::Divide(op::Add(op::AllReduce(), op::Constant()), - op::Constant()))))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))), + op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 2591ff602c8..2caa979745b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -286,7 +286,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), - {key, value}, 0, &builder, module.get())); + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -314,7 +315,8 @@ 
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), - {key, value}, 0, &builder, module.get())); + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3676de56a30..cc9489daaee 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -673,9 +673,9 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 0fecbaf391b..2bf22ec6e43 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -963,8 +963,8 @@ Status EmitBatchDotOperation( KernelSupportLibrary ksl(b); return ksl.ForWithStatus( - "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, - [&](llvm::Value* indvar) { + llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count, + /*step=*/1, [&](llvm::Value* indvar) { DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); adjusted_dim_numbers.clear_lhs_batch_dimensions(); adjusted_dim_numbers.clear_rhs_batch_dimensions(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 5abb3eb3872..9967cf28ee2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -583,7 +583,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { b_.getVoidTy(), {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo(), b_.getInt8PtrTy(), + b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); auto* key_value_sort_func = llvm::dyn_cast( @@ -616,8 +616,8 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, b_.getInt32(sort->operand_count()), sizes, - GetExecutableRunOptionsArgument(), GetProfileCountersArgument(), - less_than_function}); + b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), + GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index cb46674138a..70a6d0af02c 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -32,8 +32,8 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 a, int64 b, int64 c, char** values, int32 values_count, - 
int32* values_primitive_type_size_in_bytes, char* run_options, - int64* prof_counters, + int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, int64* prof_counters, void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. @@ -69,22 +69,27 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - std::stable_sort( - indices.get(), indices.get() + sort_dimension_elements, - [&](int64 a, int64 b) -> bool { - int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - for (int32 i = 0; i < values_count; ++i) { - comparison_values[i * 2] = values[i] + memory_index_lhs; - comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; - } - char result = 0; // Overwritten by less_than. - less_than(&result, run_options, comparison_values.get(), nullptr, - prof_counters); - return result != 0u; - }); + auto compare_function = [&](int64 a, int64 b) -> bool { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + for (int32 i = 0; i < values_count; ++i) { + comparison_values[i * 2] = values[i] + memory_index_lhs; + comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; + } + char result = 0; // Overwritten by less_than. + less_than(&result, run_options, comparison_values.get(), nullptr, + prof_counters); + return result != 0u; + }; + if (is_stable) { + std::stable_sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } else { + std::sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 4813de9ee67..50c2911c3bd 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -22,15 +22,14 @@ limitations under the License. extern "C" { // Each entry in 'values' represents a 3-dimensional shape with dimensions -// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending -// order according to the results of comparisons using the provided 'less_than' +// [a, b, c]. The 'b' dimension of each shape is sorted into ascending order +// according to the results of comparisons using the provided 'less_than' // function. 'values_count' must be > 0 and specifies the number of entries in // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' -// bytes. The elements in each 'values' shape are reordered in the same way -// according to the comparisons using the first shape. 'run_options' and -// 'prof_counters' are passed through to the less-than function, which expects -// the following arguments: +// bytes. 'is_stable' specifies whether the sorting should be stable. 
+// 'run_options' and 'prof_counters' are passed through to the less-than +// function, which expects the following arguments: // - pointer to the return value buffer (char*) // - xla::ExecutableRunOptions = 'run_options' (char*) // - pointers to the parameter buffers (char**) @@ -39,8 +38,8 @@ extern "C" { extern void __xla_cpu_runtime_KeyValueSort( tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes, char* run_options, - tensorflow::int64* prof_counters, + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); } diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc index eb6c44b70ab..9fc472ff767 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -938,6 +938,53 @@ void TiledSmallGemmEmitter::EmitTiledGemm( }); } +llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) { + llvm::Type* type = + llvm::cast(pointer_type)->getElementType(); + while (auto* array_type = llvm::dyn_cast(type)) { + type = array_type->getElementType(); + } + + return type->getPointerTo(); +} + +struct GemvBuffersWithCanonicalType { + llvm::Value* lhs_canonicalized; + llvm::Value* rhs_canonicalized; + llvm::Value* addend_canonicalized; + llvm::Value* result_canonicalized; +}; + +GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType( + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) { + // We characterize a GEMV operation via M and K, since N is implicitly 1. + // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented + // by the same GEMV that multiplies [5,6] with [1,6]. However, the + // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial + // sense -- the in memory representations are the same) since they're computed + // from the `xla::Shape`s. Since we want to be able to call the same + // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV + // inputs here into the same type. + GemvBuffersWithCanonicalType buffers_with_canonical_type; + llvm::Type* lhs_type = lhs->getType(); + llvm::Type* rhs_type = rhs->getType(); + llvm::Type* addend_type = addend ? addend->getType() : nullptr; + llvm::Type* result_type = result->getType(); + + buffers_with_canonical_type.lhs_canonicalized = + b->CreateBitCast(lhs, GetPointerToElementType(lhs_type)); + buffers_with_canonical_type.rhs_canonicalized = + b->CreateBitCast(rhs, GetPointerToElementType(rhs_type)); + buffers_with_canonical_type.addend_canonicalized = + addend ? 
b->CreateBitCast(addend, GetPointerToElementType(addend_type)) + : nullptr; + buffers_with_canonical_type.result_canonicalized = + b->CreateBitCast(result, GetPointerToElementType(result_type)); + + return buffers_with_canonical_type; +} + } // namespace void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, @@ -950,12 +997,18 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); @@ -972,12 +1025,18 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index e868dc6d889..808929be75e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1367,26 +1367,69 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm::Type* raw_value_ty = raw_value->getType(); - // Convert raw integer to float in range [0, 1) if the element is a float. + // If we're generating a floating-point value, convert the raw integer R (i.e. + // `raw_value`) to a float in the range [0, 1). + // + // The basic approach is to choose a significand and exponent such that the + // significand is uniformly distributed and the exponent is distributed, well, + // exponentially (it's more likely to be close to 0 than far from 0). + // + // An easy way to do this is to say that the significand is the first S bits + // of R, and the exponent is determined by the number of trailing zeroes in R, + // exp = 2^-(cttz(R) + 1). 
(+1 because the largest exponent should be -1; + // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.) + // + // This results in a small bias. Namely, if R has enough trailing zeroes, the + // significand and exponent will "overlap". As a concrete example, consider + // + // 20 X's 12 zeroes + // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000 + // + // Here the exponent is 2^-13 because R has 12 trailing zeroes. The + // significand is made up of the first 23 most-significant bits of R, which we + // observe contain 3 zeroes. This is biased because any random value with + // exponent 2^-12 will have a significand which ends in `000`. + // + // For f32s, this problem occurs only when there are more than 32-23 = 9 + // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the + // probability of a large bias (i.e. many trailing 0s in the significand) is + // exponentially low. So we deem this acceptable. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); - CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - // Perform the division using the float type with the same number of bits - // as the raw value to avoid overflow. - if (raw_value_size_in_bits == 32) { - elem_value = UIToFP(elem_value, b_->getFloatTy()); - elem_value = FDiv(elem_value, - llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); - } else { - elem_value = UIToFP(elem_value, b_->getDoubleTy()); - elem_value = FDiv( - elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); - } + const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics(); + const int bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics)); - if (elem_ir_ty != elem_value->getType()) { - elem_value = FPTrunc(elem_value, elem_ir_ty); - } + // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the + // implicit "1." at the beginning of the significand. + const int significand_bits = + llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1; + + llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()}, + {raw_value->getType()}, b_); + llvm::Value* significand = LShr(raw_value, bits - significand_bits); + + // Exponent bias is -127 for f32, meaning that if the exponent is E and the + // significand is S, then the value of the number is 2^(E - 127) * (1.S). + // + // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as + // E = 126 - cttz. + // + // For f64, this is all the same, except the bias is -1023. + // + // In IEEE floating point, the absolute value of the exponent bias equals + // the value of the largest possible exponent. + const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics); + llvm::Value* exponent = + Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz); + + // Now just slot everything into place! The `Trunc` is here because + // raw_value may be larger than our float destination. + elem_value = + BitCast(Trunc(Or(Shl(exponent, significand_bits), significand), + b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())), + elem_ir_ty); } // Convert the value for the requested distribution. 
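For illustration, a minimal host-side analogue of the bit trick in the hunk above (a sketch only, not part of this change; it assumes a GCC/Clang-style __builtin_ctz and a hypothetical helper name RawBitsToUnitFloat): the top 23 bits of the random word become the significand and the trailing-zero count selects the exponent.

#include <cstdint>
#include <cstring>

// Maps a uniformly random uint32_t to a float in [0, 1): top 23 bits become
// the significand, and the trailing-zero count selects the biased exponent.
float RawBitsToUnitFloat(uint32_t r) {
  // cttz with is_zero_undef=false semantics: treat cttz(0) as 32.
  const int cttz = (r == 0) ? 32 : __builtin_ctz(r);
  const uint32_t significand = r >> (32 - 23);
  // Biased f32 exponent: cttz == 0 should give 2^-1, i.e. biased exponent 126.
  const uint32_t exponent = static_cast<uint32_t>(126 - cttz);
  const uint32_t bits = (exponent << 23) | significand;
  float result;
  std::memcpy(&result, &bits, sizeof(result));  // plays the role of the BitCast
  return result;
}

The largest value this can produce is (2 - 2^-23) * 2^-1 = 1 - 2^-24, so the result stays strictly below 1, matching the half-open range described in the comment.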
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d3e2acaabd4..7d360fe38cf 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -216,8 +216,11 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); + // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. + // + // Precondition: raw_value has at least as many bits as hlo's element type. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 05980fe549c..25c4f70d89b 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -765,6 +765,7 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", + "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 9c8a1816040..6e00e4b4ff8 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -195,6 +196,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 6e64549e7e1..d2c995d87ad 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 60 +// Next ID: 61 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -175,6 +175,9 @@ message HloInstructionProto { // partners. bool is_host_transfer = 47; + // Whether this Sort instruction should be stable. + bool is_stable = 60; + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. 
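The StableSortExpander pass registered in the pipeline above rewrites is_stable Sort ops into plain sorts with an extra iota operand that is consulted only on ties. A small host-side sketch (illustrative only, not XLA code) of why an index tie breaker makes any comparison sort behave stably:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // The int member plays the role of the iota operand: the original position.
  std::vector<std::pair<float, int>> keyed = {
      {2.0f, 0}, {1.0f, 1}, {2.0f, 2}, {1.0f, 3}};
  // An unstable sort plus an index tie breaker keeps equal keys in their
  // original relative order, i.e. it behaves like a stable sort on the keys.
  std::sort(keyed.begin(), keyed.end(),
            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
              if (a.first != b.first) return a.first < b.first;
              return a.second < b.second;  // tie breaker
            });
  for (const auto& kv : keyed) {
    std::printf("%g (originally at index %d)\n", kv.first, kv.second);
  }
  return 0;
}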
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 070115604ba..b5d9e8e7f1a 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -275,7 +275,7 @@ StatusOr MakeSelectHlo(HloInstruction* pred, StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, - int64 dimension_to_sort, HloComputation::Builder* builder, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module) { CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; HloComputation* compare_computation; @@ -293,7 +293,7 @@ StatusOr MakeSortHlo( compare_computation = module->DeepCloneComputation(new_module->entry_computation(), &context); return builder->AddInstruction(HloInstruction::CreateSort( - sort_shape, dimension_to_sort, operands, compare_computation)); + sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 36b8cdc7fef..17b7a2da6a9 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -126,10 +126,10 @@ StatusOr MakeSelectHlo(HloInstruction* pred, // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending -// order. +// order. 'is_stable' specifies whether the sorting should be stable. StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, - int64 dimension_to_sort, HloComputation::Builder* builder, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module); // Creates an R1 Constant HLO instruction of the given PrimitiveType with the diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e3059e02cf0..768e3afb3b8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2363,7 +2363,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); TF_ASSERT_OK_AND_ASSIGN( - auto* sort, MakeSortHlo(keys_shape, {keys}, -1, &builder, module_.get())); + auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, + &builder, module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); @@ -2385,7 +2386,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), - {keys, values}, 0, &builder, module_.get())); + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8def61dc63d..e0a0fc4acb3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2670,12 +2670,25 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& high = parent_->GetEvaluatedLiteralFor(random->operand(1)); - std::uniform_real_distribution generator( - low.Get({}), high.Get({})); - + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. + // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + auto low_val = low.Get({}); + auto high_val = high.Get({}); + std::uniform_real_distribution generator(low_val, high_val); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span /*indexes*/) { - return generator(parent_->engine_); + while (true) { + NativeT v = generator(parent_->engine_); + if (v != high_val) { + return v; + } + } })); break; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index aa1f3a2421f..8ece90e05cc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -214,7 +214,7 @@ StatusOr> HloInstruction::CreateFromProto( << proto.called_computation_ids_size(); auto sort_operands = all_operands(); instruction = CreateSort(shape, proto.dimensions(0), all_operands(), - computations(0)); + computations(0), proto.is_stable()); break; } case HloOpcode::kTranspose: @@ -1170,9 +1170,10 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare) { + absl::Span operands, HloComputation* compare, + bool is_stable) { return absl::make_unique(shape, dimension, operands, - compare); + compare, is_stable); } /* static */ std::unique_ptr HloInstruction::CreateFusion( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e8ade80ef38..8470cf7ec53 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -384,6 +384,14 @@ class HloInstruction { // Creates a random number generation instruction that fills a shape with // random numbers from a given distribution. + // + // The parameters to the instruction are interpreted as follows: + // + // - If `distribution` is RNG_UNIFORM, generates a number in range + // [param0, param1). + // + // - If `distribution` is RNG_NORMAL, generates a normally-distributed value + // with mean `param0` and standard deviation `param1`. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, absl::Span parameters); @@ -678,10 +686,11 @@ class HloInstruction { // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at // specific index positions which should be compared, and should return a - // PRED. + // PRED. 'is_stable' specifies whether stable sorting is required. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare); + absl::Span operands, HloComputation* compare, + bool is_stable); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. 
For @@ -1286,6 +1295,9 @@ class HloInstruction { backend_config_ = std::move(config_str); } + bool is_default_config() const { return is_default_config_; } + void set_default_config() { is_default_config_ = true; } + // Returns a string representation of a proto in the format used by // raw_backend_config_string. // @@ -1734,6 +1746,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // This field is assigned to true when backend_config_ is assigned to + // a default configuration. + bool is_default_config_ = false; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 92a74187c50..7c8d98b4299 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -659,8 +659,11 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( HloSortInstruction::HloSortInstruction( const Shape& shape, int64 dimension, - absl::Span operands, HloComputation* compare) - : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { + absl::Span operands, HloComputation* compare, + bool is_stable) + : HloInstruction(HloOpcode::kSort, shape), + dimensions_({dimension}), + is_stable_(is_stable) { for (auto* value : operands) { AppendOperand(value); } @@ -672,12 +675,18 @@ HloInstructionProto HloSortInstruction::ToProto() const { for (int64 dimension : dimensions_) { proto.add_dimensions(dimension); } + proto.set_is_stable(is_stable()); return proto; } std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; + std::vector attrs; + attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}")); + if (is_stable()) { + attrs.push_back("is_stable=true"); + } + return attrs; } bool HloSortInstruction::IdenticalSlowPath( @@ -688,14 +697,17 @@ bool HloSortInstruction::IdenticalSlowPath( if (dimensions() != casted_other.dimensions()) { return false; } + if (is_stable() != casted_other.is_stable()) { + return false; + } return eq_computations(to_apply(), other.to_apply()); } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(shape, dimensions(0), - new_operands, to_apply()); + return absl::make_unique( + shape, dimensions(0), new_operands, to_apply(), is_stable()); } HloTransposeInstruction::HloTransposeInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index a0f2b46ba41..8bb37ab4359 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -447,7 +447,7 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, absl::Span operands, - HloComputation* compare); + HloComputation* compare, bool is_stable); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -460,6 +460,7 @@ class HloSortInstruction : public HloInstruction { HloInstruction* mutable_keys() { return mutable_operand(0); } // Returns the number of value operands. 
int64 values_count() const { return operand_count() - 1; } + bool is_stable() const { return is_stable_; } private: std::vector ExtraAttributesToStringImpl( @@ -474,6 +475,7 @@ class HloSortInstruction : public HloInstruction { HloCloneContext* context) const override; std::vector dimensions_; + bool is_stable_; }; class HloTransposeInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 20dbed07c54..b8e699fee2f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -895,6 +895,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; + optional is_stable = false; + attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable}; optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; @@ -902,8 +904,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), operands, to_apply.value())); + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, dimensions->at(0), operands, + to_apply.value(), is_stable.value())); break; } case HloOpcode::kTuple: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 203a7dba221..4b9453cfd78 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1145,6 +1145,24 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare } +)" +}, +// Sort (Key) is_stable=true +{ +"SortKeyStable", +R"(HloModule sort + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare +} + )" }, // Conditional diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 84399f17e5e..5a5401e3513 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -176,7 +176,7 @@ StatusOr HloRunner::Execute( TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, ExecuteWithDeviceBuffers( - /*module=*/std::move(executable), + /*executable=*/executable.get(), /*arguments=*/argument_buffers, /*profile=*/profile)); return TransferLiteralFromDevice(result); @@ -235,7 +235,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { // Get service run options. 
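Stepping back to the HloEvaluator change above: std::uniform_real_distribution(a, b) may occasionally return b itself, so the evaluator re-samples until the draw is strictly below the upper bound. A standalone sketch of that rejection loop (illustrative only; the UniformHalfOpen helper name is hypothetical and not part of this change):

#include <cstdio>
#include <random>

// Draws from the half-open interval [low, high), re-sampling in the unlikely
// case the standard library hands back exactly `high`.
double UniformHalfOpen(std::mt19937_64& engine, double low, double high) {
  std::uniform_real_distribution<double> dist(low, high);
  while (true) {
    const double v = dist(engine);
    if (v != high) {
      return v;
    }
  }
}

int main() {
  std::mt19937_64 engine(42);
  for (int i = 0; i < 4; ++i) {
    std::printf("%f\n", UniformHalfOpen(engine, 0.0, 1.0));
  }
  return 0;
}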
@@ -254,7 +254,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a6e6015d6a5..fb897aa9599 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -144,13 +144,16 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // In the following two calls, "executable" is not a unique_ptr to allow + // reuse of the Executable. This call may update the profile information in + // *executable. StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); diff --git a/tensorflow/compiler/xla/service/op_expander_pass.cc b/tensorflow/compiler/xla/service/op_expander_pass.cc index 87f0886a973..02c9d4b387b 100644 --- a/tensorflow/compiler/xla/service/op_expander_pass.cc +++ b/tensorflow/compiler/xla/service/op_expander_pass.cc @@ -36,6 +36,9 @@ StatusOr OpExpanderPass::Run(HloModule* module) { for (HloInstruction* inst : matching_instructions) { TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandInstruction(inst)); + if (expanded_root == nullptr) { + continue; + } TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); } diff --git a/tensorflow/compiler/xla/service/op_expander_pass.h b/tensorflow/compiler/xla/service/op_expander_pass.h index 794849d354b..276e3d70b8e 100644 --- a/tensorflow/compiler/xla/service/op_expander_pass.h +++ b/tensorflow/compiler/xla/service/op_expander_pass.h @@ -33,7 +33,9 @@ class OpExpanderPass : public HloModulePass { // Returns `true` if `instruction` should be expanded by this pass. virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; - // Returns a replacement for `instruction`. + // Returns a replacement for `instruction`, or nullptr if no replacement is + // neeeded (e.g. only the to_apply subcomputation of the instruction was + // modified). virtual StatusOr ExpandInstruction( HloInstruction* instruction) = 0; }; diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc new file mode 100644 index 00000000000..1aa7e5fe7c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Looks for a iota operand that can be used as tie breaker in the computation. +// If no matching iota operand is found, a iota operand is added to Sort. The +// comparison computation is adjusted to break ties using the values from the +// iota operand. +StatusOr StableSortExpander::ExpandInstruction( + HloInstruction* instruction) { + auto* sort = Cast(instruction); + HloComputation* computation = sort->parent(); + + HloInstruction* expanded_sort = nullptr; + absl::flat_hash_set used_indices; + int64 iota_index = -1; + for (const HloInstruction* operand : sort->operands()) { + // We can only use the iota operand if it has an iota dimension which is the + // same as the dimension to sort. Also it should have an integral type that + // is large enough for the number of elements in the sort dimension. For + // now, we only allow S32, because we expect to find a S32 iota operand for + // all Sort ops which are created by TopK. + // TODO(b/122298745): Also support other types. + if (operand->opcode() == HloOpcode::kIota && + Cast(operand)->iota_dimension() == + sort->sort_dimension() && + operand->shape().element_type() == S32) { + iota_index = sort->operand_index(operand); + break; + } + } + + // If there is currently no iota operand which we could use for making the + // sort stable, we will have to add a new such operand. + if (iota_index == -1) { + Shape iota_shape = sort->operand(0)->shape(); + // We might need to use S64 if the number of elements in the sort dimension + // is bigger than 2^31 - 1. + // TODO(b/122298745): Handle Sort ops where S32 is too small for the number + // of elements in the sort dimension. + if (iota_shape.dimensions(sort->sort_dimension()) > + std::numeric_limits::max()) { + return Unimplemented( + "Stable sorting of more than 2^31-1 elements is not implemented"); + } + iota_shape.set_element_type(S32); + auto iota = computation->AddInstruction( + HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); + + // Create a new comparator. + auto comparator = sort->to_apply(); + absl::flat_hash_map> + replacements; + std::vector> extra_parameters; + std::vector extra_parameter_ptrs; + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".lhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2 + 1, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".rhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements), + extra_parameter_ptrs))); + + // Replace the original sort op. 
+ std::vector new_operands(sort->operands().begin(), + sort->operands().end()); + new_operands.push_back(iota); + std::vector new_shapes = sort->operand_count() == 1 + ? std::vector{sort->shape()} + : sort->shape().tuple_shapes(); + new_shapes.push_back(iota_shape); + Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, new_operands)); + + // Add a "wrapper" around the new sort op to make sure we have the same + // shape as before. For the rank 1 case, we only need a GetTupleElement, + // otherwise we create a Tuple consisting of GetTupleElements of the new + // sort. + std::vector tuple_elements; + tuple_elements.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + sort->operand(i)->shape(), new_sort, i))); + } + expanded_sort = tuple_elements[0]; + if (tuple_elements.size() > 1) { + expanded_sort = computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + } + sort = Cast(new_sort); + iota_index = sort->operand_count() - 1; + } + + // Modify the computation to break ties using the iota operand. + auto comparator = sort->to_apply(); + std::vector instructions_postorder = + comparator->MakeInstructionPostOrder(); + absl::flat_hash_map replacements; + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. + auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second; + }; + HloInstruction* old_root = comparator->root_instruction(); + // The comparison computation gets 2 * n parameters (n being the number of + // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two + // different scalars of operand i of Sort which are to be compared. The + // comparison computation should induce a strict weak order, so if + // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to + // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the + // values to be compared are equivalent, and perform a tie-breaker comparison. + // + // We clone each instruction with at least one operand, but use as new + // operands of the instruction the replacements of the original operands. + // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This + // should make sure that the cloned root instruction gives the result of the + // comparison computation when being called with each scalar pair reversed. + // parameters corresponding to the iota operand. 
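+  // For illustration, with a single sort operand the iota becomes operand 1
+  // and the comparator has parameters p0..p3: the original root computes
+  // lt(p0, p1) while the clone built below computes lt(p1, p0). Under a
+  // strict weak order the two can only agree when both are false, i.e. the
+  // keys are equivalent, and in that case p2 and p3 (the iota values) decide.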
+ for (int64 i = 0; i < comparator->num_parameters(); ++i) { + replacements[comparator->parameter_instruction(i)] = + comparator->parameter_instruction(i ^ 1); + } + HloInstruction* cloned_root = nullptr; + for (HloInstruction* inst : instructions_postorder) { + if (inst->operand_count() == 0) { + continue; + } + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->operands()) { + new_operands.push_back(replace(operand)); + } + auto new_instruction = + inst->CloneWithNewOperands(inst->shape(), new_operands); + replacements[inst] = new_instruction.get(); + if (inst == old_root) { + cloned_root = new_instruction.get(); + } + comparator->AddInstruction(std::move(new_instruction)); + } + CHECK_NE(cloned_root, nullptr); + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + HloInstruction* same = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + HloInstruction* tie_breaker = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kLt, + comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1))); + HloInstruction* new_root = + comparator->AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, + old_root)); + comparator->set_root_instruction(new_root); + + return expanded_sort; +} + +bool StableSortExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSort && + Cast(instruction)->is_stable(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.h b/tensorflow/compiler/xla/service/stable_sort_expander.h new file mode 100644 index 00000000000..31b6fd92d25 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. 
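+//
+// For example (see the tests), a rank-1 stable sort of keys
+//
+//   sort(keys), is_stable=true
+//
+// is rewritten to sort an S32 iota operand alongside the keys,
+//
+//   get-tuple-element(sort(keys, iota()), index=0)
+//
+// with the comparator extended so that equivalent keys are ordered by their
+// iota values.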
+class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc new file mode 100644 index 00000000000..a62d953e6e8 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -0,0 +1,358 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using StableSortExpanderTest = HloTestBase; + +// Checks whether 'a' and 'b' are roots of equivalent computations, except that +// parameters 2 * i and 2 * i + 1 are switched. +bool IsSameComputationExceptParams(const HloInstruction* a, + const HloInstruction* b) { + if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) { + return false; + } + if (a->opcode() == HloOpcode::kParameter) { + // Check that parameters were switched. + return a->parameter_number() == (b->parameter_number() ^ 1); + } + // If the operation has no operands, it should actually be the same. + if (a->operand_count() == 0) { + return a == b; + } + // Otherwise recursively compare all operands. + for (int64 i = 0; i < a->operand_count(); ++i) { + if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) { + return false; + } + } + return true; +} + +// Check that the comparison computation has been modified to add a tie breaker +// using 'iota_parameter'. +void CheckComputationHasTieBreaker(const HloInstruction* root, + int64 iota_parameter) { + // With the tie breaker, the root instruction should be + // Select(Eq(Comp(), CompReverse()), Lt(), Comp()) + // with Comp() being the original comparison function, and CompReverse() being + // the copied comparison function where the parameters are reversed. Lt() is + // the tie breaker comparison using the Iota operand. + ASSERT_EQ(root->opcode(), HloOpcode::kSelect); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + + // Check that the tie breaker instruction is correct. 
+ EXPECT_THAT(root->operand(1), + GmockMatch(m::Lt(m::Parameter(iota_parameter * 2), + m::Parameter(iota_parameter * 2 + 1)))); + EXPECT_EQ(root->operand(2), root->operand(0)->operand(0)); + + // Check that Comp() and CompReverse() are equivalent except that + // CompReverse() has reversed parameters. + EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0), + root->operand(0)->operand(1))); +} + +TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortReuseIotaOperandComplicatedComparison) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + max = u32[] constant(2147483647) + zero = s32[] constant(0) + lhs.signed = s32[] bitcast-convert(p.0.lhs) + lhs.unsigned = u32[] bitcast-convert(p.0.lhs) + lhs.flipped = u32[] subtract(max, lhs.unsigned) + lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) + lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) + rhs.signed = s32[] bitcast-convert(p.0.rhs) + rhs.unsigned = u32[] bitcast-convert(p.0.rhs) + rhs.flipped = u32[] subtract(max, rhs.unsigned) + rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) + rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) + ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] 
parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Tuple( + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0), + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, HonorIsStableFlag) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=false + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie()); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortDontReuseIotaOperandWrongDimension) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=0 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. 
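+  // (The expander preserved the original shape by emitting
+  // tuple(get-tuple-element(new_sort, 0), get-tuple-element(new_sort, 1));
+  // folding that wrapper lets the matcher below look at the sort directly.)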
+ AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + ROOT neg = s32[64,8732]{1,0} negate(sort) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, 
GmockMatch(m::Negate(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/1); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 55160261392..6f61fc44166 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1072,7 +1072,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); TF_ASSERT_OK_AND_ASSIGN( - auto* sort, MakeSortHlo(keys_shape, {keys}, 0, &builder, module_.get())); + auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, + &builder, module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); @@ -1094,7 +1095,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { TF_ASSERT_OK_AND_ASSIGN( auto* sort, MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), - {keys, values}, 0, &builder, module_.get())); + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); computation_ = module_->AddEntryComputation(builder.Build()); RunAnalysis(); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index db1c9274690..a67aa6ebfe2 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1146,7 +1146,7 @@ xla_test( xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], - shard_count = 40, + shard_count = 31, tags = [ "optonly", ], diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 11343ddfd0b..262b77264f5 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1188,6 +1188,8 @@ std::vector GetEinsumTestCases() { p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, p{v{6}, v{6, 7}, "b,bc->c"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, p{v{77}, v{77}, "a,a->a"}, p{v{77}, v{77, 55}, "a,ab->ba"}, p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, @@ -1265,5 +1267,51 @@ ENTRY %test { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(2) + rhs_1 = f32[1,40] parameter(1) + + dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + ROOT result = f32[20,1] divide(dot_0, dot_1) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { + // Tests for a caching bug in the XLA CPU backend. 
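+  // (Both dots below reduce to GEMVs with the same m and k, so they can
+  // share an outlined kernel; the buffer-type canonicalization in
+  // tiled_dot_emitter.cc ensures the cached kernel is re-invoked with
+  // matching LLVM pointer types even though the xla::Shapes differ.)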
+ absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs_0 = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(1) + lhs_1 = f32[1,40] parameter(2) + rhs_1 = f32[20,40] parameter(3) + + dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + dot_0_reshaped = f32[20] reshape(dot_0) + dot_1_reshaped = f32[20] reshape(dot_1) + + ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d9d54fd2556..0151981ef16 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -205,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 78bdd336e0a..3c2bcbb5df5 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -173,6 +173,11 @@ class HloTestBase : public ::testing::Test { Literal ExecuteAndTransfer(std::unique_ptr module, absl::Span arguments); + // Executes the given module on multiple replicas. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index c743dfd32b3..cda2d7c7c6b 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -30,6 +30,11 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = **kwargs ) +def xla_py_proto_library(**kwargs): + # Note: we don't currently define a proto library target for Python in OSS. + _ignore = kwargs + pass + def xla_py_grpc_library(**kwargs): # Note: we don't currently define any special targets for Python GRPC in OSS. 
_ignore = kwargs diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 6a7f1065253..343f43b7159 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -122,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .Device(DEVICE_XLA_GPU) .HostMemory("handle"), diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index e2c223b3dbb..6af73ecc853 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -215,27 +217,29 @@ class XRTAllocateFromTensorOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + std::vector minor_to_major; if (ctx->HasAttr("layouts")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major)); } OP_REQUIRES( ctx, tf_shapes_.size() == dtypes_.size(), errors::InvalidArgument("shapes and dtypes must be the same length")); std::vector xla_shapes; + xla_shapes.reserve(tf_shapes_.size()); for (int i = 0; i < tf_shapes_.size(); i++) { xla::Shape xla_shape; OP_REQUIRES_OK( ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); - xla_shapes.push_back(xla_shape); + xla_shapes.push_back(std::move(xla_shape)); } if (xla_shapes.size() > 1 || make_tuple) { shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); } else { shape_.Swap(&xla_shapes.front()); } - if (!minor_to_major_.empty()) { + if (!minor_to_major.empty()) { xla::Shape shape_with_layouts; - OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major, /*layout_func=*/nullptr, &shape_with_layouts)); shape_.Swap(&shape_with_layouts); @@ -304,7 +308,6 @@ class XRTAllocateFromTensorOp : public OpKernel { private: std::vector tf_shapes_; DataTypeVector dtypes_; - std::vector minor_to_major_; xla::Shape shape_; }; @@ -487,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - xla::Literal literal; + xla::Literal literal(allocation->on_host_shape()); OP_REQUIRES_OK( ctx, 
allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); @@ -499,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that reads a device-resident tuple to host memory and returns it as a +// literal. +template +class XRTReadToTensorOp : public OpKernel { + public: + explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + } + ~XRTReadToTensorOp() override = default; + XRTReadToTensorOp(const XRTReadToTensorOp&) = delete; + XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReadToTensorOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not + // just scalars.) + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + + if (discard_) { + VLOG(2) << "Releasing handle " << allocation_handle; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( + rm, allocation_handle)); + } + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + + xla::Shape shape = allocation->on_host_shape(); + int output = 0; + Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus( + &shape, + [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status { + if (subshape->IsTuple()) return Status::OK(); + + xla::PrimitiveType xla_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType( + ctx->expected_output_dtype(output), &xla_type)); + if (xla_type != subshape->element_type()) { + return errors::InvalidArgument( + "Type mismatch between buffer type (", subshape->ToString(), + ") and tensor type (", + DataTypeString(ctx->expected_output_dtype(output)), + ") for output tensor ", output); + } + + TensorShape output_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape)); + + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_output(output, output_shape, &output_tensor)); + + XRTTupleAllocation* sub; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + allocation, index, &sub, /*alias_parent_allocation=*/true)); + core::ScopedUnref sub_unref(sub); + + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, + &literal)); + TF_RETURN_IF_ERROR(sub->ToLiteral( + device_ref.backend(), device_ref.device_ordinal(), &literal)); + + ++output; + return Status::OK(); + }); + OP_REQUIRES_OK(ctx, status); + } + bool discard_; + DataTypeVector dtypes_; +}; + // Op that writes a new literal value into device-resident memory. 
// Op that writes a new literal value into device-resident memory. template <class DeviceAccessor> class XRTWriteLiteralOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 2e743fec496..8832270fb27 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -151,6 +151,27 @@ releases the handle. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTReadToTensor") + .Input("handles: int64") + .Attr("release_handles: bool = False") + .Attr("dtypes: list(type)") + .Output("tensors: dtypes") + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc( + R"( +Copies allocated values from device memory and returns them as zero or more +Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned. +In general, the tensors returned for a handle correspond to an in-order traversal +of the tuple-tree value referenced by the handle. + +'handles' contains ids returned from Ops that produced on-device allocations. +At present, only a single (scalar) handle is supported. +'dtypes' are the expected types for each `Tensor` to be returned. If the +expected and actual tensor types do not match, an error is returned. +'release_handles': if True, `handles` are released. +'tensors' are the output Tensors. +)"); + REGISTER_OP("XRTReleaseAllocationHandle") .Input("handle: int64") .SetShapeFn(tensorflow::shape_inference::NoOutputs) diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 78a1b6afc05..1b3bcbea4c1 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -220,7 +220,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal) { + xla::MutableLiteralBase* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); @@ -234,9 +234,8 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, " has been released"); } } - TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), shaped_buffer)); - return Status::OK(); + return transfer_manager->TransferLiteralFromDevice(stream.get(), + shaped_buffer, *literal); } Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index ddf2656e6f5..6519da30d02 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -147,7 +147,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal); + xla::MutableLiteralBase* literal); // Write a new literal value to the allocation.
Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 25f2640e35a..0173b8bb064 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -218,7 +218,6 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", - "//tensorflow/contrib/tpu:all_ops", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 47d910d42a2..5a8b2ba9caf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -399,8 +399,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def testQuantileRegression(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -413,7 +413,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=12, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -428,31 +428,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper >= 0.92) self.assertTrue(frac_below_upper <= 0.98) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - # Multi-dimensional quantile regression. 
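# The single-dimension assertions above check the empirical coverage of the
# 0.95-quantile model: roughly 95% of the labels should fall below its
# predictions, with a +/- 3% tolerance band. A standalone numpy sketch of the
# same check on synthetic data (stand-ins, not the test's input functions):
import numpy as np
y = np.random.normal(size=1000)                # stand-in labels
upper = np.full_like(y, np.percentile(y, 95))  # stand-in 0.95-quantile predictions
frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3)
assert 0.92 <= frac_below_upper <= 0.98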
def testQuantileRegressionMultiDimLabel(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -467,7 +448,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): quantiles=[0.95], learner_config=learner_config, label_dimension=2, - num_trees=100, + num_trees=18, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -490,35 +471,6 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_both_below_upper >= 0.91) self.assertTrue(frac_both_below_upper <= 0.99) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - label_dimension=2, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) - class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -712,11 +664,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) - # One dimensional quantile regression. - def testQuantileRegression(self): + # Quantile regression in core is the same as in non core estimator, so we + # just check that it does not fail. 
+ def testQuantileRegressionDoesNotThroughException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 + learner_config.constraints.max_tree_depth = 1 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -731,112 +684,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=1, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) model_upper.train(input_fn=train_input_fn, steps=1000) result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper >= 0.92) - self.assertTrue(frac_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - - # Multi-dimensional quantile regression. - def testQuantileRegressionMultiDimLabel(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) - - train_input_fn, test_input_fn, y = _quantile_regression_input_fns( - two_dimension=True) - y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) - - # 95% percentile. - model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.95], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_upper.train(input_fn=train_input_fn, steps=1000) - result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - count_below_upper = np.count_nonzero(upper > y, axis=0) - count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) - frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) - frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) - frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper_0 >= 0.92) - self.assertTrue(frac_below_upper_0 <= 0.98) - self.assertTrue(frac_below_upper_1 >= 0.92) - self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.91) - self.assertTrue(frac_both_below_upper <= 0.99) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py index 10a58316ec5..204f52b034f 100644 --- a/tensorflow/contrib/distribute/python/input_lib_test.py +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib @@ -214,33 +213,5 @@ class InputIteratorMultiWorkerTest( expected_values, sess) -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - if __name__ == "__main__": test.main() diff --git 
a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 2eca1d1877f..cc9cee31bef 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops import array_ops from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -68,6 +69,20 @@ def simple_functional_model(): return model +def simple_subclassed_model(num_labels=_NUM_CLASS): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + return _SimpleMLP(num_labels) + + def simple_multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -1184,5 +1199,109 @@ class TestDistributionStrategyWithDatasets(test.TestCase, atol=1e-4, rtol=1e-4) +class TestRegularizerLoss(test.TestCase, parameterized.TestCase): + class IdentityRegularizer(keras.regularizers.Regularizer): + + def __call__(self, x): + return array_ops.identity(x) + + class AddLayer(keras.layers.Layer): + + def build(self, _): + self.v = self.add_weight( + 'v', (), initializer='ones', + regularizer=TestRegularizerLoss.IdentityRegularizer()) + + def call(self, inputs): + return inputs + self.v + + @staticmethod + def loss_fn(_, y_pred): + return y_pred + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_regularizer_loss(self, distribution): + batch_size = 2 + if not distributed_training_utils.global_batch_size_supported(distribution): + batch_size //= distribution.num_replicas_in_sync + + # Given an input x, which is always 1, and variable v, this model computes + # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is + # initialized to 1. Therefore, this model computes Loss=1+2v, and so the + # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples + # in a batch and then multiplied by the learning rate of 1. As a result, + # the model update for one batch should subtract 2 from v, resulting in v + # being -1. If the regularizer loss is not scaled correctly by number of + # replicas, the variable value will be incorrect when number of replicas + # >1. For e.g. it will be -2 if num replicas = 2. + with distribution.scope(): + x = keras.layers.Input(shape=(), batch_size=batch_size) + y = TestRegularizerLoss.AddLayer()(x) + model = keras.models.Model(inputs=x, outputs=y) + opt = gradient_descent_keras.SGD(1.) 
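# A worked version of the arithmetic in the comment above, assuming a single
# replica and the SGD(1.) optimizer configured on the preceding line:
#   loss(v)  = x + v + regularizer(v) = 1 + 2v    (x is always 1, regularizer(v) = v)
#   dloss/dv = 2
#   one SGD step with lr = 1:  v_new = v - 1 * 2 = 1 - 2 = -1
# which is the value asserted at the end of this test via self.assertEqual(-1.0, v).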
+ model.compile(opt, loss=TestRegularizerLoss.loss_fn) + model.fit(x=np.array([1., 1.], dtype=np.float32), + y=np.array([1., 1.], dtype=np.float32), + batch_size=batch_size) + v = model.get_weights()[0] + self.assertEqual(-1.0, v) + + +class TestDistributionStrategyWithKerasModels(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_sequential_model(self, distribution): + with distribution.scope(): + model = simple_sequential_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((20, 10), np.float32) + targets = np.zeros((20, 2), np.float32) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) + + @combinations.generate(all_strategy_combinations()) + def test_distribution_strategy_on_functional_model(self, distribution): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) + + # TODO(b/124377929): Remove error assertions once subclassed models + # are supported in DistributedStrategy. + @combinations.generate(all_strategy_combinations_minus_default()) + def test_distribution_strategy_on_subclassed_model(self, distribution): + with distribution.scope(): + model = simple_subclassed_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 2), dtype=np.float32) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.predict(inputs, steps=1) + + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.evaluate(inputs, targets, steps=1) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index db0868fb2c4..386e4cf69b7 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -377,7 +377,10 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_pip", + "no_windows", + ], deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index d319aa7986d..92016e6a839 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -19,16 +19,25 @@ tf_cc_binary( "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:candidate_sampling_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:io_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:lookup_ops_op_lib", 
"//tensorflow/core:manip_ops_op_lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:parsing_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:random_ops_op_lib", "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:string_ops_op_lib", "//tensorflow/core:training_ops_op_lib", "//tensorflow/core:user_ops_op_lib", diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 63843b993c1..93701249cc8 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -10,6 +10,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -45,6 +46,28 @@ tf_gen_op_wrapper_py( deps = [":memory_stats_ops_op_lib"], ) +tf_gen_op_wrapper_cc( + name = "memory_stats_ops", + out_ops_file = "memory_stats_ops", +) + +cc_library( + name = "memory_stats_cc", + srcs = ["memory_stats_ops.cc"], + hdrs = ["memory_stats_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":memory_stats_kernels", + ":memory_stats_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "memory_stats_py", srcs = [ diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 294dbddcb5e..d580ca6eb6d 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -23,17 +23,13 @@ package( ], ) -cc_library( - name = "all_ops", +py_library( + name = "tpu_py", + srcs = ["python/ops/tpu_ops.py"], + srcs_version = "PY2AND3", deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:tpu_ops_gen", ], ) @@ -75,7 +71,6 @@ py_library( ":functional", ":tpu_embedding", ":tpu_lib", - ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -98,122 +93,15 @@ py_library( ], ) -tf_gen_op_libs( - op_lib_names = [ - "cross_replica_ops", - "heartbeat_ops", - "host_compute_ops", - "infeed_ops", - "outfeed_ops", - "replication_ops", - "tpu_configuration_ops", - "tpu_embedding_ops", - "tpu_ordinal_selector_op", - "functional_ops", - ], - deps = [ - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ops.so", - srcs = [ - "ops/cross_replica_ops.cc", - "ops/heartbeat_ops.cc", - 
"ops/host_compute_ops.cc", - "ops/infeed_ops.cc", - "ops/outfeed_ops.cc", - "ops/replication_ops.cc", - "ops/tpu_configuration_ops.cc", - "ops/tpu_embedding_ops.cc", - ], - deps = [ - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ops", - hidden = [ - "SendTPUEmbeddingGradients", - "EnqueueTPUEmbeddingIntegerBatch", - "EnqueueTPUEmbeddingSparseBatch", - "EnqueueTPUEmbeddingSparseTensorBatch", - ], - deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ordinal_selector_op.so", - srcs = ["ops/tpu_ordinal_selector_op.cc"], -) - -tf_custom_op_py_library( - name = "tpu_ordinal_selector_py", - srcs = ["python/ops/tpu_ordinal_selector_op.py"], - dso = [":python/ops/_tpu_ordinal_selector_op.so"], - kernels = [ - ":tpu_ordinal_selector_op_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":tpu_ordinal_selector_op", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ordinal_selector_op", - deps = [ - ":tpu_ordinal_selector_op_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_functional_ops.so", - srcs = ["ops/functional_ops.cc"], -) - -tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "python/tpu/gen_functional_ops.py", - hidden = [ - "TPUPartitionedCall", - ], - deps = [":functional_ops_op_lib"], -) - -tf_custom_op_py_library( +py_library( name = "functional", srcs = ["python/tpu/functional.py"], - dso = [":python/ops/_functional_ops.so"], - kernels = [ - ":functional_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = [ "//visibility:public", ], deps = [ - ":gen_functional_ops", + "//tensorflow/python:tpu_ops_gen", ], ) @@ -229,26 +117,6 @@ py_library( ], ) -tf_custom_op_py_library( - name = "tpu_py", - srcs = ["python/ops/tpu_ops.py"], - dso = [":python/ops/_tpu_ops.so"], - kernels = [ - ":all_ops", - ], - srcs_version = "PY2AND3", - deps = [ - ":profiler", - ":tpu_ops", - "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:util", - ], -) - py_library( name = "tpu", srcs = [ @@ -327,7 +195,6 @@ py_library( ":datasets", ":functional", ":profiler", - ":tpu_ordinal_selector_py", ":tpu_py", "//tensorflow/compiler/xla/experimental/xla_sharding", "//tensorflow/compiler/xla/python_api:xla_shape", @@ -347,6 +214,7 @@ py_library( "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tpu_ops_gen", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", @@ -470,13 +338,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_ops", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", + "//tensorflow/python:tpu_ops_gen", 
"//tensorflow/python:variable_scope", "//tensorflow/python:variables", "@six_archive//:six", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 541fbf33a30..7ad30c61e42 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -3,8 +3,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_profiler_all_protos") tf_proto_library( name = "tpu_profiler_proto", @@ -12,7 +12,7 @@ tf_proto_library( has_services = 1, cc_api_version = 2, cc_grpc_version = 1, - protodeps = [":op_profile_proto"] + tf_additional_all_protos(), + protodeps = tf_profiler_all_protos() + tf_additional_all_protos(), visibility = ["//visibility:public"], ) @@ -22,13 +22,13 @@ cc_library( hdrs = ["dump_tpu_profile.h"], visibility = ["//visibility:public"], deps = [ - ":op_profile_proto_cc", ":tpu_profiler_proto_cc", ":trace_events_proto_cc", ":trace_events_to_json", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler:protos_all_cc", ], ) @@ -45,15 +45,10 @@ tf_cc_binary( ], visibility = ["//visibility:public"], deps = [ - ":dump_tpu_profile", - ":tpu_profiler_analysis_proto_cc", - ":tpu_profiler_proto_cc", ":version", - "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/profiler/rpc/client:capture_profile", ], ) @@ -87,13 +82,6 @@ tf_cc_test( ], ) -tf_proto_library( - name = "op_profile_proto", - srcs = ["op_profile.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "tpu_profiler_analysis_proto", srcs = ["tpu_profiler_analysis.proto"], diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 1c5ea2d997a..508b929658d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,235 +18,11 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. 
-#include "grpcpp/grpcpp.h" - -#include -#include -#include - -#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h" #include "tensorflow/contrib/tpu/profiler/version.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/util/command_line_flags.h" -namespace tensorflow { -namespace tpu { -namespace { - -using ::tensorflow::TPUProfileAnalysis; -using ::tensorflow::TPUProfiler; - -constexpr uint64 kMaxEvents = 1000000; - -string GetCurrentTimeStampAsString() { - char s[128]; - std::time_t t = std::time(nullptr); - CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0); - return s; -} - -Status ValidateHostPortPair(const string& host_port) { - uint32 port; - std::vector parts = str_util::Split(host_port, ':'); - // Must be host:port, port must be a number, host must not contain a '/', - // host also must not be empty. - if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || - parts[0].find("/") != string::npos || parts[0].empty()) { - return errors::InvalidArgument("Could not interpret \"", host_port, - "\" as a host-port pair."); - } - return Status::OK(); -} - -ProfileRequest PopulateProfileRequest(int duration_ms, - const string& repository_root, - const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request; - request.set_duration_ms(duration_ms); - request.set_max_events(kMaxEvents); - if (tensorflow::str_util::StartsWith(repository_root, "gs://")) { - // For backward compatibilities, only generate tracetable etc when the - // user provide a GCS path for model directory. - request.set_repository_root(repository_root); - request.set_session_id(session_id); - } - request.add_tools("op_profile"); - request.add_tools("input_pipeline"); - request.add_tools("memory_viewer"); - request.add_tools("overview_page"); - *request.mutable_opts() = opts; - return request; -} - -// Returns whether the returned trace is empty. -// Failure are handled by CHECK, i.e. abort() -bool Profile(const string& service_addr, const string& logdir, int duration_ms, - const string& repository_root, const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - ProfileResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); - - if (!response.encoded_trace().empty()) { - TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( - logdir, session_id, "", response, &std::cout)); - // Print this at the end so that it's not buried in irrelevant LOG messages. 
- std::cout - << "NOTE: using the trace duration " << duration_ms << "ms." - << std::endl - << "Set an appropriate duration (with --duration_ms) if you " - "don't see a full step in your trace or the captured trace is too " - "large." - << std::endl; - } - - return response.encoded_trace().empty(); -} - -// Start a new profiling session that include all the hosts included in -// hostnames, for the time interval of duration_ms. Possibly save the profiling -// result in the directory specified by repository_root and session_id. -bool NewSession(const string& service_addr, - const std::vector& hostnames, - int duration_ms, const string& repository_root, - const string& session_id, const ProfileOptions& opts) { - NewProfileSessionRequest new_session_request; - *new_session_request.mutable_request() = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - new_session_request.set_repository_root(repository_root); - new_session_request.set_session_id(session_id); - for (const auto& hostname : hostnames) { - new_session_request.add_hosts(hostname); - } - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. - channel_args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - // TODO(jiesun): GRPC support following relevant naming scheme: - // 1. dns:///host:port - // 2. ipv4:host:port or ipv6:[host]:port - // We might need to change the prefix which depends on what TPU name resolver - // will give us. - std::unique_ptr stub = - TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - NewProfileSessionResponse new_session_response; - TF_QCHECK_OK(FromGrpcStatus( - stub->NewSession(&context, new_session_request, &new_session_response))); - - std::cout << "Profile session succeed for host(s):" - << str_util::Join(hostnames, ",") << std::endl; - return new_session_response.empty_trace(); -} - -// Starts tracing on a single or multiple TPU hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -void StartTracing(const tensorflow::string& service_addr, - const tensorflow::string& logdir, - const tensorflow::string& workers_list, - bool include_dataset_ops, int duration_ms, - int num_tracing_attempts) { - // Use the current timestamp as the run name. - tensorflow::string session_id = GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - io::JoinPath(logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(workers_list, ","); - - bool empty_trace = false; - int remaining_attempts = num_tracing_attempts; - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(include_dataset_ops); - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, - repository_root, session_id, opts); - } else { - tensorflow::string tpu_master = service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. 
Automatically retrying." - << std::endl - << std::endl; - } - - if (empty_trace) { - std::cout << "No trace event is collected after " << num_tracing_attempts - << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." - << std::endl; - } -} - -MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { - MonitorRequest request; - request.set_duration_ms(duration_ms); - request.set_monitoring_level(monitoring_level); - return request; -} - -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). -void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries) { - for (int query = 0; query < num_queries; ++query) { - MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - MonitorResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); - - std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 - << "):\n\n" - << response.data() << std::flush; - } -} - -} // namespace -} // namespace tpu -} // namespace tensorflow - int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; @@ -301,7 +77,7 @@ int main(int argc, char** argv) { return 2; } tensorflow::Status status = - tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr); if (!status.ok()) { std::cout << status.error_message() << std::endl; std::cout << usage.c_str() << std::endl; @@ -324,12 +100,12 @@ int main(int argc, char** argv) { << FLAGS_service_addr << " for " << duration_ms << "ms and show metrics for " << num_queries << " time(s)." << std::endl; - tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, - FLAGS_monitoring_level, num_queries); + tensorflow::profiler::client::StartMonitoring( + FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); } else { - tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, - FLAGS_workers_list, FLAGS_include_dataset_ops, - duration_ms, num_tracing_attempts); + tensorflow::profiler::client::StartTracing( + FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, + FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts); } return 0; } diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index b4b06a40a2c..e6e355cca57 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tpu/profiler/op_profile.pb.h" #include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" #include "tensorflow/contrib/tpu/profiler/trace_events_to_json.h" #include "tensorflow/core/lib/core/errors.h" @@ -29,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/op_profile.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { @@ -88,7 +88,7 @@ Status DumpTraceToLogDirectory(StringPiece run_dir, const string& host_prefix, Status DumpOpProfileToLogDirectory(StringPiece run_dir, const string& host_prefix, - const tpu::op_profile::Profile& profile, + const profiler::op_profile::Profile& profile, std::ostream* os) { string path = JoinPath(run_dir, StrCat(host_prefix, kJsonOpProfileFileName)); string json; diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index da4a95e0450..299af06b38a 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -3,7 +3,7 @@ package tensorflow; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/protobuf/config.proto"; -import "tensorflow/contrib/tpu/profiler/op_profile.proto"; +import "tensorflow/core/profiler/op_profile.proto"; // The TPUProfiler service retrieves performance information about // the programs running on connected TPUs over a period of time. @@ -96,7 +96,7 @@ message ProfileResponse { // Assembles a hierarchical performance profile based on HLOs in trace events. // If the trace covers multiple programs, the longest-running one is analyzed. // See op_profile.proto for the detailed semantics of the returned profile. - tpu.op_profile.Profile op_profile = 4; + profiler.op_profile.Profile op_profile = 4; // Data payload for each required tools. repeated ProfileToolData tool_data = 6; diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 55f7c6bcbc1..2320306ba9b 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -28,16 +28,10 @@ from tensorflow.python.platform import tf_logging as logging if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops import gen_tpu_ops - from tensorflow.contrib.tpu.ops.gen_tpu_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader + from tensorflow.python.ops import gen_tpu_ops + from tensorflow.python.ops.gen_tpu_ops import * # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _tpu_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ops.so")) - def _create_default_group_assignment(): num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -237,12 +231,12 @@ if platform.system() != "Windows": """ if learning_rates is None: learning_rates = [] - return gen_tpu_ops._send_tpu_embedding_gradients( + return gen_tpu_ops.send_tpu_embedding_gradients( inputs=inputs, learning_rates=learning_rates, config=config, name=name) send_tpu_embedding_gradients.__doc__ = ( - gen_tpu_ops._send_tpu_embedding_gradients.__doc__) + gen_tpu_ops.send_tpu_embedding_gradients.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_integer_batch(batch, @@ -268,14 +262,14 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_integer_batch( + return gen_tpu_ops.enqueue_tpu_embedding_integer_batch( batch=batch, device_ordinal=device_ordinal, 
mode_override=mode_override, name=name) enqueue_tpu_embedding_integer_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_sparse_batch(sample_indices, @@ -317,7 +311,7 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch( + return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch( sample_indices=sample_indices, embedding_indices=embedding_indices, aggregation_weights=aggregation_weights, @@ -327,7 +321,7 @@ if platform.system() != "Windows": name=name) enqueue_tpu_embedding_sparse_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__) # pylint: disable=protected-access def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, @@ -375,7 +369,7 @@ if platform.system() != "Windows": """ if mode_override is None: mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch( + return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( sample_indices=sample_indices, embedding_indices=embedding_indices, aggregation_weights=aggregation_weights, @@ -386,7 +380,7 @@ if platform.system() != "Windows": name=name) enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__) + gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__) else: # We have already built the appropriate libraries into the binary via CMake diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py index 5ca38cd1bae..6917ac2e1a7 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -23,15 +23,12 @@ import platform if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * + from tensorflow.python.ops.gen_tpu_ops import tpu_ordinal_selector from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _tpu_ordinal_selector_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ordinal_selector_op.so")) - else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py index 24c85156e53..3d04c64033b 100644 --- a/tensorflow/contrib/tpu/python/tpu/functional.py +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -18,22 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform +from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import gen_functional_ops - - -TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access - - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * - - from tensorflow.contrib.util import loader - from 
tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_partitioned_call_op = loader.load_op_library( - resource_loader.get_path_to_datafile("../ops/_functional_ops.so") - ) +TPUPartitionedCall = tpu_ops.tpu_partitioned_call # pylint: disable=invalid-name diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index eb99a18d839..fcad7b29726 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -25,7 +25,6 @@ import re import six from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 @@ -35,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_tpu_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 4f761e3599b..afe0a04d3b5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -32,7 +32,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.ops import tpu_ordinal_selector_op from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional @@ -1364,13 +1363,13 @@ def call_computation(computation, # TPU core with every `Session.run()` call. Note that the entire inference # graph executes on a single core, and that invocations of this graph # will round-robin among the cores attached to a host. 
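# A condensed, self-contained sketch of the dispatch pattern described in the
# comment above and changed in the hunk below. The toy computation is
# hypothetical; the TPUPartitionedCall / tpu_ordinal_selector wiring mirrors
# the new call_computation code in this file.
import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional
from tensorflow.python.framework import function

def computation():
  # Hypothetical single-core inference graph.
  x = tf.constant([[1.0, 2.0]])
  w = tf.constant([[3.0], [4.0]])
  return [tf.matmul(x, w)]

@function.Defun(capture_resource_var_by_value=False)
def tpu_subgraph():
  return computation()

outputs = tpu_functional.TPUPartitionedCall(
    args=tpu_subgraph.captured_inputs,
    # The ordinal is chosen at run time, so successive Session.run() calls
    # round-robin over the TPU cores attached to this host.
    device_ordinal=tpu_ops.tpu_ordinal_selector(),
    Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
    f=tpu_subgraph)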
- @function.Defun() + @function.Defun(capture_resource_var_by_value=False) def tpu_subgraph(): return computation() return tpu_functional.TPUPartitionedCall( args=tpu_subgraph.captured_inputs, - device_ordinal=tpu_ordinal_selector_op.tpu_ordinal_selector(), + device_ordinal=tpu_ops.tpu_ordinal_selector(), Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], f=tpu_subgraph) else: @@ -2465,8 +2464,14 @@ class TPUEstimator(estimator_lib.Estimator): device_assignment = ctx.device_assignment else: device_assignment = None - tensors_on_cpu = tpu.rewrite_for_inference( - tpu_computation, device_assignment=device_assignment) + + if self._experimental_exported_model_uses_all_cores: + tensors_on_cpu = tpu.rewrite( + tpu_computation, device_assignment=device_assignment) + else: + tensors_on_cpu = tpu.rewrite_for_inference( + tpu_computation, device_assignment=device_assignment) + (estimator_spec, export_outputs_dict, export_outputs_list, predictions_dict) = ( tpu_capture.get()) diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 07dbd5ca8d6..ada08f95ae4 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -22,7 +22,9 @@ cc_library( "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", ], diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index cc242d0e3c9..fb93e8ddd3c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1168,6 +1168,29 @@ tf_gen_op_libs( deps = [":lib"], ) +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + deps = [ + ":lib", + ":lib_proto_parsing", + ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1284,6 +1307,16 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", @@ -1551,6 +1584,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", + ":ops", ":protos_all_cc", ":shape_inference_testutil", ":tensor_testutil", @@ -1897,6 +1931,7 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", + "**/tpu_*", ], ), visibility = ["//visibility:public"], @@ -2969,6 +3004,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", + "common_runtime/ring_alg.h", + "common_runtime/ring_gatherer.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -3025,6 +3062,8 @@ 
tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_alg.cc", + "common_runtime/ring_gatherer.cc", "common_runtime/ring_reducer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", @@ -3936,7 +3975,6 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - ":cudnn_rnn_ops", "//tensorflow/core", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3996,6 +4034,35 @@ tf_cc_tests_gpu( ], ) +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "common_runtime/ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", size = "medium", diff --git a/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt new file mode 100644 index 00000000000..d6f28bd022b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt @@ -0,0 +1,67 @@ +op { + graph_op_name: "AllToAll" + in_arg { + name: "input" + description: <