Revert "Branch 175277161"

commit d0a5d885d6 (parent 047d7965d2)
Author: Martin Wicke, 2017-11-10 12:26:11 -08:00, committed by GitHub (GPG Key ID: 4AEE18F83AFDEB23)
542 changed files with 11012 additions and 20928 deletions


@@ -42,7 +42,7 @@ The Code of Conduct also applies within project spaces and in public spaces when
 Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between.
-If the behaviour is threatening or harassing, or for other reasons requires immediate escalation, please see below.
+If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below.
 However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute.


@@ -73,11 +73,11 @@ $ python
 ## For more information
-* [TensorFlow website](https://www.tensorflow.org)
+* [TensorFlow Website](https://www.tensorflow.org)
 * [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
 * [TensorFlow Model Zoo](https://github.com/tensorflow/models)
 * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
-* [TensorFlow course at Stanford](https://web.stanford.edu/class/cs20si)
+* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
 Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.


@@ -43,6 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
                           'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
 _TF_OPENCL_VERSION = '1.2'
 _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
+_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
 def is_windows():
@@ -487,11 +488,10 @@ def set_cc_opt_flags(environ_cp):
   cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
                                                  question, default_cc_opt_flags)
   for opt in cc_opt_flags.split():
-    write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt))
-  host_opt = '-march=native'  # It should be safe on the same build host.
-  write_to_bazelrc(
-      'build:opt --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
-  write_to_bazelrc('build:opt --define with_default_optimizations=true')
+    host_opt = '-march=native'  # It should be safe on the same build host.
+    write_to_bazelrc(
+        'build:opt --cxxopt=%s --copt=%s' % (opt, opt) +
+        ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
 def set_tf_cuda_clang(environ_cp):
@@ -641,7 +641,7 @@ def set_tf_cuda_version(environ_cp):
   write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
-def set_tf_cunn_version(environ_cp):
+def set_tf_cudnn_version(environ_cp):
   """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
   ask_cudnn_version = (
       'Please specify the cuDNN version you want to use. '
@@ -887,6 +887,27 @@ def set_computecpp_toolkit_path(environ_cp):
   write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
                               computecpp_toolkit_path)
+def set_trisycl_include_dir(environ_cp):
+  """Set TRISYCL_INCLUDE_DIR"""
+  ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
+                             'include directory. (Use --config=sycl_trisycl '
+                             'when building with Bazel) '
+                             '[Default is %s]: '
+                            ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+  while True:
+    trisycl_include_dir = get_from_env_or_user_or_default(
+        environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
+        _DEFAULT_TRISYCL_INCLUDE_DIR)
+    if os.path.exists(trisycl_include_dir):
+      break
+    print('Invalid triSYCL include directory, %s cannot be found'
+          % (trisycl_include_dir))
+  # Set TRISYCL_INCLUDE_DIR
+  environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
+  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
+                              trisycl_include_dir)
 def set_mpi_home(environ_cp):
   """Set MPI_HOME."""
@@ -999,6 +1020,8 @@ def main():
   environ_cp['TF_NEED_GCP'] = '0'
   environ_cp['TF_NEED_HDFS'] = '0'
   environ_cp['TF_NEED_JEMALLOC'] = '0'
+  environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
+  environ_cp['TF_NEED_COMPUTECPP'] = '0'
   environ_cp['TF_NEED_OPENCL'] = '0'
   environ_cp['TF_NEED_S3'] = '0'
   environ_cp['TF_CUDA_CLANG'] = '0'
@@ -1021,17 +1044,21 @@ def main():
   set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
                 False, 'verbs')
-  set_action_env_var(environ_cp, 'TF_NEED_OPENCL', 'OpenCL', False)
-  if environ_cp.get('TF_NEED_OPENCL') == '1':
+  set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
+  if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
     set_host_cxx_compiler(environ_cp)
     set_host_c_compiler(environ_cp)
-    set_computecpp_toolkit_path(environ_cp)
+    set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
+    if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
+      set_computecpp_toolkit_path(environ_cp)
+    else:
+      set_trisycl_include_dir(environ_cp)
   set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
   if (environ_cp.get('TF_NEED_CUDA') == '1' and
       'TF_CUDA_CONFIG_REPO' not in environ_cp):
     set_tf_cuda_version(environ_cp)
-    set_tf_cunn_version(environ_cp)
+    set_tf_cudnn_version(environ_cp)
     set_tf_cuda_compute_capabilities(environ_cp)
     set_tf_cuda_clang(environ_cp)
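For reference, a minimal sketch (illustration only, not part of the diff) of the single combined build:opt line that the restored set_cc_opt_flags() writes for each optimization flag; the CC_OPT_FLAGS value '-mavx' is an assumed example input.

# Sketch only: mirrors the restored formatting logic; '-mavx' is a hypothetical flag.
opt = '-mavx'
host_opt = '-march=native'  # assumed safe on the same build host, as in the diff
line = ('build:opt --cxxopt=%s --copt=%s' % (opt, opt) +
        ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
print(line)
# build:opt --cxxopt=-mavx --copt=-mavx --host_cxxopt=-march=native --host_copt=-march=native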


@@ -54,6 +54,15 @@ config_setting(
     visibility = ["//visibility:public"],
 )
+config_setting(
+    name = "raspberry_pi_armeabi",
+    values = {
+        "crosstool_top": "@local_config_arm_compiler//:toolchain",
+        "cpu": "armeabi",
+    },
+    visibility = ["//visibility:public"],
+)
 config_setting(
     name = "android_arm",
     values = {
@@ -110,7 +119,7 @@ config_setting(
 config_setting(
     name = "no_tensorflow_py_deps",
-    define_values = {"no_tensorflow_py_deps": "true"},
+    values = {"define": "no_tensorflow_py_deps=true"},
     visibility = ["//visibility:public"],
 )
@@ -166,122 +175,55 @@ config_setting(
 # TODO(jhseu): Enable on other platforms other than Linux.
 config_setting(
     name = "with_jemalloc_linux_x86_64",
-    define_values = {"with_jemalloc": "true"},
-    values = {"cpu": "k8"},
+    values = {
+        "cpu": "k8",
+        "define": "with_jemalloc=true",
+    },
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_jemalloc_linux_ppc64le",
-    define_values = {"with_jemalloc": "true"},
-    values = {"cpu": "ppc"},
-    visibility = ["//visibility:public"],
-)
-config_setting(
-    name = "with_default_optimizations",
-    define_values = {"with_default_optimizations": "true"},
+    values = {
+        "cpu": "ppc",
+        "define": "with_jemalloc=true",
+    },
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_gcp_support",
-    define_values = {"with_gcp_support": "true"},
+    values = {"define": "with_gcp_support=true"},
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_hdfs_support",
-    define_values = {"with_hdfs_support": "true"},
+    values = {"define": "with_hdfs_support=true"},
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_s3_support",
-    define_values = {"with_s3_support": "true"},
-    visibility = ["//visibility:public"],
-)
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
name = "with_gcp_support_windows_override",
define_values = {"with_gcp_support": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_hdfs_support_windows_override",
define_values = {"with_hdfs_support": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_s3_support_windows_override",
define_values = {"with_s3_support": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gcp_support_android_override",
define_values = {"with_gcp_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_hdfs_support_android_override",
define_values = {"with_hdfs_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_s3_support_android_override",
define_values = {"with_s3_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_gcp_support_ios_override",
define_values = {"with_gcp_support": "true"},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_hdfs_support_ios_override",
define_values = {"with_hdfs_support": "true"},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_s3_support_ios_override",
define_values = {"with_s3_support": "true"},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
+    values = {"define": "with_s3_support=true"},
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_xla_support",
-    define_values = {"with_xla_support": "true"},
+    values = {"define": "with_xla_support=true"},
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_gdr_support",
-    define_values = {"with_gdr_support": "true"},
+    values = {"define": "with_gdr_support=true"},
     visibility = ["//visibility:public"],
 )
 config_setting(
     name = "with_verbs_support",
-    define_values = {"with_verbs_support": "true"},
+    values = {"define": "with_verbs_support=true"},
     visibility = ["//visibility:public"],
 )
@@ -355,7 +297,7 @@ config_setting(
     visibility = ["//visibility:public"],
 )
-# Make a dummy rule that we can change "default" in select statements to.
+# Make a dummy rule that we can chaqnge "default" in select statements to.
 # to disable dependencies in copybara.
 config_setting(
     name = "dummy_disabled_internal",
@@ -384,6 +326,14 @@ filegroup(
     visibility = ["//tensorflow:__subpackages__"],
 )
+py_library(
+    name = "tensorflow_py",
+    srcs = ["__init__.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = ["//tensorflow/python"],
+)
 filegroup(
     name = "all_opensource_files",
     data = [
@@ -737,11 +687,3 @@ tf_cc_shared_object(
         "//tensorflow/core:tensorflow",
     ],
 )
-py_library(
-    name = "tensorflow_py",
-    srcs = ["__init__.py"],
-    srcs_version = "PY2AND3",
-    visibility = ["//visibility:public"],
-    deps = ["//tensorflow/python"],
-)


@@ -890,8 +890,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                           TF_Status* status) {
   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
   if (attr == nullptr) {
-    status->status = InvalidArgument("Operation '", oper->node.name(),
-                                     "' has no attr named '", attr_name, "'.");
+    status->status =
+        InvalidArgument("Operation has no attr named '", attr_name, "'.");
   }
   return attr;
 }


@@ -383,7 +383,7 @@ TEST(CAPI, Graph) {
   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
   ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
-  EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
+  EXPECT_EQ(string("Operation has no attr named 'missing'."),
             string(TF_Message(s)));
   // Make a constant oper with the scalar "3".
@@ -1054,7 +1054,7 @@ class CApiColocationTest : public ::testing::Test {
     TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
     if (expected.empty()) {
       ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
-      EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
+      EXPECT_EQ(std::string("Operation has no attr named '_class'."),
                 std::string(TF_Message(s_)));
       return;
     }


@@ -39,7 +39,6 @@ tf_cuda_library(
 tf_cuda_library(
     name = "c_api_internal",
     hdrs = ["c_api_internal.h"],
-    visibility = ["//tensorflow:internal"],
     deps = [
         ":c_api",
         ":runtime",
@@ -106,6 +105,7 @@ tf_cc_test(
 cc_library(
     name = "tape",
+    srcs = ["tape.cc"],
    hdrs = ["tape.h"],
     visibility = ["//tensorflow:internal"],
     deps = [

tensorflow/c/eager/tape.cc (new file, 102 lines)

@@ -0,0 +1,102 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/tape.h"
namespace tensorflow {
namespace eager {
bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) {
for (int64 i : tensor_ids) {
if (tensor_tape_.find(i) != tensor_tape_.end()) {
return true;
}
}
return false;
}
void GradientTape::Watch(int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
void GradientTape::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id, void* backward_function,
const std::function<void()>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id)) {
backward_function_deleter();
return;
}
std::vector<int64> ids;
ids.reserve(input_tensor_id.size());
for (int64 i : input_tensor_id) {
tensor_usage_[i]++;
ids.push_back(i);
}
const int64 op_id = next_op_id_++;
std::vector<TapeTensor> tensors;
tensors.reserve(output_tensors.size());
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
tensor_tape_[o.id] = op_id;
tensor_usage_[o.id] = 1;
tensors.push_back(o);
}
op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function,
backward_function_deleter};
}
void GradientTape::DeleteTrace(int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
}
it->second--;
if (it->second != 0) {
return;
}
tensor_usage_.erase(it);
auto tensor_op_it = tensor_tape_.find(tensor_id);
if (tensor_op_it == tensor_tape_.end()) {
return;
}
const int64 op_id = tensor_op_it->second;
if (op_id == -1) {
// Do not delete watched tensors.
return;
}
tensor_tape_.erase(tensor_op_it);
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
}
for (int64 id : op_it->second.input_tensor_id) {
DeleteTrace(id);
}
op_it->second.backward_function_deleter();
op_tape_.erase(op_it);
}
std::pair<TensorTape, OpTape> GradientTape::Export() {
return {std::move(tensor_tape_), std::move(op_tape_)};
}
} // namespace eager
} // namespace tensorflow


@@ -19,7 +19,6 @@ limitations under the License.
 // maintains the data structures required to do so.
 #include <unordered_map>
-#include <unordered_set>
 #include <vector>
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
@@ -37,14 +36,13 @@ struct TapeTensor {
 };
 // Represents an entry in the tape.
-template <typename BackwardFunction>
 struct OpTapeEntry {
   string op_type;
   std::vector<TapeTensor> output_tensor_info;
   std::vector<int64> input_tensor_id;
   // TODO(apassos) consider narrowing down this interface.
-  BackwardFunction* backward_function;
+  void* backward_function;
   // Should be called before deleting the backward function. TODO(apassos) use
   // unique_ptrs to ensure this happens.
@@ -57,68 +55,13 @@ struct OpTapeEntry {
 using TensorTape = std::unordered_map<int64, int64>;
 // Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
+using OpTape = std::unordered_map<int64, OpTapeEntry>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's either
// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
// to allow their size to be computed and they need to be passable to a backward
// function and deleted (as the backprop code creates lots of gradients the user
// is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction>
class VSpace {
public:
virtual ~VSpace() {}
// Returns the number of elements in the gradient tensor.
virtual int64 NumElements(Gradient* tensor) const = 0;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
virtual Gradient* AggregateGradients(
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
// Returns a Tensor which is filled with ones and like the input.
virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
BackwardFunction* backward_function,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0;
// Deletes the input tensor.
virtual void DeleteGradient(Gradient* gradient) const = 0;
};
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
 class GradientTape {
  public:
   GradientTape() {}
-  ~GradientTape() {
-    for (const auto& pair : op_tape_) {
-      pair.second.backward_function_deleter();
-    }
-  }
   bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
@@ -127,24 +70,19 @@ class GradientTape {
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
-                       BackwardFunction* backward_function,
+                       void* backward_function,
                        const std::function<void()>& backward_function_deleter);
   void DeleteTrace(int64 tensor_id);
-  // Consumes the internal state of the tape (so cannot be called more than
-  // once) and produces the gradient of the target tensors with respect to the
-  // source tensors. The output gradients are used if not empty and not
-  // null. The result is populated with one tensor per target element.
-  Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
-                         gtl::ArraySlice<int64> target_tensor_ids,
-                         gtl::ArraySlice<int64> source_tensor_id,
-                         gtl::ArraySlice<Gradient*> output_gradients,
-                         std::vector<Gradient*>* result);
+  // Note: it is only valid to call Export once per tape, and after calling
+  // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch,
+  // Record, and Delete have undefined behavior).
+  std::pair<TensorTape, OpTape> Export();
  private:
   TensorTape tensor_tape_;
-  OpTape<BackwardFunction> op_tape_;
+  OpTape op_tape_;
   int64 next_op_id_{0};
   // Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -152,429 +90,6 @@ class GradientTape {
   std::unordered_map<int64, int64> tensor_usage_;
 };
// Template instantiations here
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids) {
for (int64 i : tensor_ids) {
if (tensor_tape_.find(i) != tensor_tape_.end()) {
return true;
}
}
return false;
}
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id)) {
backward_function_deleter();
return;
}
std::vector<int64> ids;
ids.reserve(input_tensor_id.size());
for (int64 i : input_tensor_id) {
tensor_usage_[i]++;
ids.push_back(i);
}
const int64 op_id = next_op_id_++;
std::vector<TapeTensor> tensors;
tensors.reserve(output_tensors.size());
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
tensor_tape_[o.id] = op_id;
tensor_usage_[o.id] = 1;
tensors.push_back(o);
}
op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
op_type, tensors, ids, backward_function, backward_function_deleter};
}
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
}
it->second--;
if (it->second != 0) {
return;
}
tensor_usage_.erase(it);
auto tensor_op_it = tensor_tape_.find(tensor_id);
if (tensor_op_it == tensor_tape_.end()) {
return;
}
const int64 op_id = tensor_op_it->second;
if (op_id == -1) {
// Do not delete watched tensors.
return;
}
tensor_tape_.erase(tensor_op_it);
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
}
for (int64 id : op_it->second.input_tensor_id) {
DeleteTrace(id);
}
op_it->second.backward_function_deleter();
op_tape_.erase(op_it);
}
// Terminology:
//
// - op: a possibly composite operation, which has an entry in the tape
// - target: dy in dx/dy
// - source: dx in dx/dy
// - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each Tensor
// is used as an input to an op (so we know when we're done computing gradients
// for that Tensor). We also count, for each tape entry, how many of its output
// Tensors need gradients to be computed (Tensors which are not used do not need
// any gradients to be computed).
//
// Finally, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set usually is a subset of the set of
// targets (not all since targets which have outputs in the tape will not have
// gradients available initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and update
// the gradients of its inputs. Once we have computed all gradients for a single
// input we can mark this input as done, and this can trigger adding an entry to
// the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.
namespace {
template <typename BackwardFunction>
struct BackpropInitialState {
OpTape<BackwardFunction> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
std::unordered_map<int64, int64> tensor_usage_counts;
// Maps from op ID to how many output tensors of this op still need to have
// their gradients computed.
std::unordered_map<int64, int64> op_missing_tensor;
};
template <typename BackwardFunction>
BackpropInitialState<BackwardFunction> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
OpTape<BackwardFunction> op_tape,
const std::unordered_set<int64>& sources_set) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
BackpropInitialState<BackwardFunction> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
auto op_id_it = tensor_tape.find(tensor_id);
if (op_id_it == tensor_tape.end()) {
continue;
}
int64 op_id = op_id_it->second;
auto op_it = op_tape.find(op_id);
auto result_op_it = result.op_tape.find(op_id);
if (op_id == -1 || op_it == op_tape.end() ||
result_op_it != result.op_tape.end()) {
continue;
}
CHECK(result.op_tape.emplace(op_id, op_it->second).second);
for (auto it : op_it->second.input_tensor_id) {
auto count_it = result.tensor_usage_counts.find(it);
if (count_it != result.tensor_usage_counts.end()) {
count_it->second++;
} else {
result.tensor_usage_counts[it] = 1;
if (sources_set.find(it) == sources_set.end() &&
tensor_tape.find(it) != tensor_tape.end()) {
tensor_stack.push_back(it);
}
}
}
op_tape.erase(op_it);
}
for (auto& pair : result.tensor_usage_counts) {
auto it = tensor_tape.find(pair.first);
if (it != tensor_tape.end() && it->second != -1) {
result.op_missing_tensor[it->second] += 1;
}
}
// Call destructors for all unneeded gradient functions.
for (const auto& op_pair : op_tape) {
op_pair.second.backward_function_deleter();
}
return result;
}
template <typename BackwardFunction>
std::vector<int64> InitialStack(
const OpTape<BackwardFunction>& op_tape,
const std::unordered_map<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
result.push_back(op_entry.first);
}
}
return result;
}
template <typename Gradient, typename BackwardFunction>
Status InitialGradients(
const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
const std::unordered_map<int64, int64>& tensor_usage_counts,
std::unordered_map<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
if (!output_gradients.empty() && output_gradients[i] != nullptr) {
// TODO(apassos) figure out how to print debugging information here.
return errors::InvalidArgument(
"A gradient was provided for a tensor which is used as part of the "
"computation.");
}
} else {
if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id);
if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) {
return errors::Internal(
"Internal state of the gradient tape is invalid.");
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].id == id) {
found = true;
(*result)[id].push_back(
vspace.Ones(op_it->second.output_tensor_info[j].shape,
op_it->second.output_tensor_info[j].dtype));
break;
}
}
if (!found) {
return errors::Internal(
"Internal state of the gradient tape is invalid.");
}
} else {
// No record of the target tensor found on the tape, so no gradient
// needs to be computed from it. Do nothing.
}
} else {
(*result)[id].push_back(output_gradients[i]);
}
}
}
return Status::OK();
}
} // namespace
// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensor to save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
template <typename Gradient, typename BackwardFunction>
Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
std::unordered_map<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients);
auto cleanup = [&state]() {
// Release all backprop functions
for (const auto& pair : state.op_tape) {
pair.second.backward_function_deleter();
}
};
if (!s.ok()) {
cleanup();
return s;
}
std::unordered_map<int64, int64> gradients_size;
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
// time, for better CPU backprop performance.
VLOG(1) << "Initial stack:";
if (VLOG_IS_ON(1)) {
for (auto t : op_stack) {
VLOG(1) << " " << t;
}
}
std::unordered_map<string, std::unordered_set<int>>
functions_accept_none_for_indices({
{"SoftmaxCrossEntropyWithLogits", {1}},
{"FusedBatchNorm", {1, 2, 3, 4}},
});
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;
op_stack.pop_back();
auto op_it = state.op_tape.find(op);
if (op_it == state.op_tape.end()) {
// It is possible for ops to end up on the stack if they are unrelated to
// the target; we should just skip them.
continue;
}
auto trace = std::move(op_it->second);
state.op_tape.erase(op_it);
std::vector<Gradient*> out_gradients;
out_gradients.reserve(trace.output_tensor_info.size());
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
const int64 id = trace.output_tensor_info[i].id;
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
functions_accept_none_for_indices.find(trace.op_type);
if (func_name_it != functions_accept_none_for_indices.end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(
vspace.Zeros(trace.output_tensor_info[i].shape,
trace.output_tensor_info[i].dtype));
}
} else {
out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
}
}
}
std::vector<Gradient*> in_gradients;
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
if (!s.ok()) {
VLOG(1) << "Gradient function failed.";
cleanup();
return s;
}
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
<< trace.input_tensor_id.size() << " sources";
for (int i = 0; i < in_gradients.size(); ++i) {
const int64 id = trace.input_tensor_id[i];
if (in_gradients[i] != nullptr) {
auto& unaggregated_grads = gradients[id];
unaggregated_grads.push_back(in_gradients[i]);
if (unaggregated_grads.size() > kMinAggregateCount) {
auto size_it = gradients_size.find(id);
int64 size;
if (size_it == gradients_size.end()) {
size = vspace.NumElements(unaggregated_grads[0]);
gradients_size.emplace(id, size);
} else {
size = size_it->second;
}
if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
unaggregated_grads.clear();
unaggregated_grads.push_back(grad);
}
}
}
auto usage_count_it = state.tensor_usage_counts.find(id);
if (usage_count_it == state.tensor_usage_counts.end()) {
VLOG(1) << "Tensor " << id << " not used";
continue;
}
usage_count_it->second--;
if (usage_count_it->second > 0) {
VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
continue;
}
auto tape_it = tensor_tape_.find(id);
if (tape_it == tensor_tape_.end()) {
VLOG(1) << "Tensor " << id
<< " has no associated op. Deleting gradient";
auto grad_it = gradients.find(id);
if (grad_it != gradients.end()) {
for (auto g : grad_it->second) {
vspace.DeleteGradient(g);
}
gradients.erase(grad_it);
}
continue;
}
const int64 op_id = tape_it->second;
if (op_id == -1) {
VLOG(1) << "Tensor " << id << " is source";
continue;
}
auto missing_it = state.op_missing_tensor.find(op_id);
if (missing_it != state.op_missing_tensor.end()) {
missing_it->second--;
VLOG(1) << "Op " << op_id << " missing " << missing_it->second
<< " output gradients";
if (missing_it->second == 0) {
op_stack.push_back(op_id);
}
}
}
}
CHECK(state.op_tape.empty());
result->reserve(source_tensor_ids.size());
for (auto is : source_tensor_ids) {
auto grad_it = gradients.find(is);
if (grad_it == gradients.end()) {
result->push_back(nullptr);
} else {
if (grad_it->second.size() == 1) {
result->push_back(grad_it->second[0]);
} else {
result->push_back(vspace.AggregateGradients(grad_it->second));
}
gradients.erase(grad_it);
}
}
VLOG(1) << "Final gradients size: " << gradients.size();
for (auto grad_pair : gradients) {
for (const auto& g : grad_pair.second) {
vspace.DeleteGradient(g);
}
}
return Status::OK();
}
 }  // namespace eager
 }  // namespace tensorflow


@@ -119,7 +119,7 @@ def tf_library(name, graph, config,
           out_nodes_file,
       ] + freeze_saver_srcs,
       outs=[freeze_file],
-      cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
+      cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
           freeze_args),
       tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
       tags=tags,
@@ -130,6 +130,10 @@ def tf_library(name, graph, config,
   header_file = name + ".h"
   object_file = name + ".o"
   ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
+  if type(tfcompile_flags) == type(""):
+    flags = tfcompile_flags
+  else:
+    flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
   native.genrule(
       name=("gen_" + name),
       srcs=[
@@ -148,7 +152,7 @@ def tf_library(name, graph, config,
           " --target_triple=" + target_llvm_triple() +
           " --out_header=$(@D)/" + header_file +
           " --out_object=$(@D)/" + object_file +
-          " " + (tfcompile_flags or "")),
+          flags),
      tools=[tfcompile_tool],
      visibility=visibility,
      testonly=testonly,
@@ -185,7 +189,7 @@ def tf_library(name, graph, config,
           " --cpp_class=" + cpp_class +
           " --target_triple=" + target_llvm_triple() +
           " --out_session_module=$(@D)/" + session_module_pb +
-          " " + (tfcompile_flags or "")),
+          flags),
      tools=[tfcompile_tool],
      visibility=visibility,
      testonly=testonly,
@@ -195,8 +199,7 @@ def tf_library(name, graph, config,
   # The cc_library rule packaging up the header and object file, and needed
   # kernel implementations.
-  need_xla_data_proto = (tfcompile_flags and
-                         tfcompile_flags.find("--gen_program_shape") != -1)
+  need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
   native.cc_library(
       name=name,
       srcs=[object_file],
@@ -253,7 +256,7 @@ def tf_library(name, graph, config,
      ],
      outs=[test_file],
      cmd=("sed " + sed_replace +
-          " $(location //tensorflow/compiler/aot:test.cc) " +
+          " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
           "> $(OUTS)"),
      tags=tags,
   )
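A minimal standalone sketch of the tfcompile_flags normalization restored above: a plain string is passed through, while a list is shell-quoted and joined. The helper name _join_tfcompile_flags is hypothetical, introduced here only for illustration.

def _join_tfcompile_flags(tfcompile_flags):
  # Strings pass through unchanged, preserving the older calling convention.
  if type(tfcompile_flags) == type(""):
    return tfcompile_flags
  # List entries are single-quoted so embedded quotes survive the genrule cmd.
  return " ".join(["'" + arg.replace("'", "'\\''") + "'"
                   for arg in (tfcompile_flags or [])])

print(_join_tfcompile_flags(["--xla_cpu_multi_thread_eigen=false"]))
# '--xla_cpu_multi_thread_eigen=false'
print(_join_tfcompile_flags(None))  # -> "" when no flags are given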


@@ -257,6 +257,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
+  options.local_executable_has_hybrid_result = true;
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;


@@ -227,7 +227,10 @@ Status XlaCompilationCache::BuildExecutable(
   }
   xla::ExecutableBuildOptions build_options;
   build_options.set_device_ordinal(client_->default_device_ordinal());
+  build_options.set_platform(client_->platform());
   build_options.set_result_layout(result.xla_output_shape);
+  build_options.set_has_hybrid_result(
+      options.local_executable_has_hybrid_result);
   auto compile_result =
       client_->Compile(*result.computation, argument_layouts, build_options);


@@ -657,7 +657,7 @@ tf_library(
     cpp_class = "LSTMLayerInference",
     graph = "lstm_layer_inference.pbtxt",
     tags = ["manual"],
-    tfcompile_flags = "--xla_cpu_multi_thread_eigen=false",
+    tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"],
 )
 # -----------------------------------------------------------------------------


@@ -36,7 +36,7 @@ class FusedBatchNormTest(XLATestCase):
     x_square = x * x
     x_square_sum = np.sum(x_square, (0, 1, 2))
     x_sum = np.sum(x, axis=(0, 1, 2))
-    element_count = np.size(x) / int(np.shape(x)[0])
+    element_count = np.size(x) / int(np.shape(x)[-1])
     mean = x_sum / element_count
     var = x_square_sum / element_count - mean * mean
     normalized = (x - mean) / np.sqrt(var + epsilon)
@@ -64,8 +64,9 @@ class FusedBatchNormTest(XLATestCase):
     return grad_x, grad_scale, grad_offset
   def testInference(self):
-    x_shape = [2, 2, 6, 2]
-    scale_shape = [2]
+    channel = 3
+    x_shape = [2, 2, 6, channel]
+    scale_shape = [channel]
     x_val = np.random.random_sample(x_shape).astype(np.float32)
     scale_val = np.random.random_sample(scale_shape).astype(np.float32)
@@ -74,8 +75,8 @@ class FusedBatchNormTest(XLATestCase):
     with self.test_session() as sess, self.test_scope():
       # To avoid constant folding
       t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
-      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
-      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
+      offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
       epsilon = 0.001
       y_ref, mean_ref, var_ref = self._reference_training(
           x_val, scale_val, offset_val, epsilon, data_format)
@@ -97,8 +98,9 @@ class FusedBatchNormTest(XLATestCase):
       self.assertAllClose(y_val, y_ref, atol=1e-3)
   def _testLearning(self, use_gradient_checker):
-    x_shape = [2, 2, 6, 2]
-    scale_shape = [2]
+    channel = 3
+    x_shape = [2, 2, 6, channel]
+    scale_shape = [channel]
     x_val = np.random.random_sample(x_shape).astype(np.float32)
     scale_val = np.random.random_sample(scale_shape).astype(np.float32)
@@ -109,8 +111,8 @@ class FusedBatchNormTest(XLATestCase):
     with self.test_session() as sess, self.test_scope():
       # To avoid constant folding
       t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
-      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
-      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
+      offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
       epsilon = 0.001
       y, mean, var = nn.fused_batch_norm(
           t_val,
@@ -154,8 +156,9 @@ class FusedBatchNormTest(XLATestCase):
   def testGradient(self):
     # TODO(b/64270657): Use gradient_checker here in addition to comparing with
     # this reference implementation.
-    x_shape = [2, 2, 6, 2]
-    scale_shape = [2]
+    channel = 3
+    x_shape = [2, 2, 6, channel]
+    scale_shape = [channel]
     grad_val = np.random.random_sample(x_shape).astype(np.float32)
     x_val = np.random.random_sample(x_shape).astype(np.float32)
     scale_val = np.random.random_sample(scale_shape).astype(np.float32)
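A small NumPy sketch (illustration only; shapes mirror the restored test values) of why the reference statistics divide by the number of elements per channel, i.e. the last axis for NHWC input, giving one mean and variance per channel:

import numpy as np

channel = 3
x = np.random.random_sample([2, 2, 6, channel]).astype(np.float32)
element_count = np.size(x) / int(np.shape(x)[-1])  # 2*2*6 elements per channel
mean = np.sum(x, axis=(0, 1, 2)) / element_count
var = np.sum(x * x, axis=(0, 1, 2)) / element_count - mean * mean
print(mean.shape, var.shape)  # (3,) (3,) -- one statistic per channel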


@@ -49,9 +49,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
     case tensorflow::DT_UINT64:
       *type = xla::U64;
       return Status::OK();
-    case tensorflow::DT_BFLOAT16:
-      *type = xla::BF16;
-      return Status::OK();
     case tensorflow::DT_HALF:
       *type = xla::F16;
       return Status::OK();


@@ -236,6 +236,12 @@ class XlaCompiler {
     // to the computation.
     bool allow_cpu_custom_calls = false;
+    // If 'local_executable_has_hybrid_result', the top-level pointers of the
+    // result tuple of compiled programs are stored in host memory and the
+    // nested buffers in device memory, otherwise the whole result tuple is
+    // stored in device memory.
+    bool local_executable_has_hybrid_result = false;
     // If not nullptr, populate_resource_manager is called with the
     // compilation device's resource manager when the compilation
     // device is created, and can be used to create metadata objects


@@ -77,7 +77,6 @@ cc_library(
     hdrs = ["types.h"],
     visibility = [":friends"],
     deps = [
-        "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
@@ -340,7 +339,6 @@ cc_library(
     name = "array",
     hdrs = ["array.h"],
     deps = [
-        ":status",
         ":types",
         "//tensorflow/core:lib",
     ],


@@ -23,10 +23,8 @@ limitations under the License.
 #include <iterator>
 #include <memory>
 #include <random>
-#include <type_traits>
 #include <vector>
-#include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -37,63 +35,10 @@ limitations under the License.
 namespace xla {
namespace array_impl {
// conjunction
//
// Performs a compile-time logical AND operation on the passed types (which
// must have `::value` members convertible to `bool`. Short-circuits if it
// encounters any `false` members (and does not compare the `::value` members
// of any remaining arguments).
//
// This metafunction is designed to be a drop-in replacement for the C++17
// `std::conjunction` metafunction.
template <typename... Ts>
struct conjunction;
template <typename T, typename... Ts>
struct conjunction<T, Ts...>
: std::conditional<T::value, conjunction<Ts...>, T>::type {};
template <>
struct conjunction<> : std::true_type {};
// A type trait that is valid when all elements in a parameter pack are of
// integral type.
template <typename... T>
using pack_is_integral = conjunction<std::is_integral<T>...>;
// Compares three same-sized vectors elementwise. For each item in `values`,
// returns false if any of values[i] is outside the half-open range [starts[i],
// ends[i]).
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
const C3& range_ends) {
for (size_t i = 0, e = values.size(); i < e; ++i) {
if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
return false;
}
}
return true;
}
} // namespace array_impl
 // General N dimensional array class with arbitrary value type.
 template <typename T>
 class Array {
  public:
-  // Type inference can have a hard time parsing very deep initializer list
-  // nests, especially if one or more dimensions is one as the compiler just
-  // sees a single-element integer initializer. These typedefs allow casting
-  // explicitly with less typing.
-  using InitializerList1D = std::initializer_list<T>;
-  using InitializerList2D = std::initializer_list<InitializerList1D>;
-  using InitializerList3D = std::initializer_list<InitializerList2D>;
-  using InitializerList4D = std::initializer_list<InitializerList3D>;
-  using value_type = T;
   // Creates a new array with the specified dimensions.
   explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
       : Array(sizes, T()) {}
@@ -108,7 +53,7 @@ class Array {
   // Creates a 2D array from the given nested initializer list. The outer
   // initializer list is the first dimension, the inner is the second dimension.
   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
-  Array(InitializerList2D values)
+  Array(std::initializer_list<std::initializer_list<T>> values)
       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
     int64 idx = 0;
     for (const auto& it1 : values) {
@@ -122,7 +67,8 @@ class Array {
   // Creates a 3D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
-  Array(InitializerList3D values)
+  Array(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
+            values)
       : Array(ToInt64Vector({values.size(), values.begin()->size(),
                              values.begin()->begin()->size()})) {
     int64 idx = 0;
@@ -139,7 +85,9 @@ class Array {
   // Creates a 4D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
-  Array(InitializerList4D values)
+  Array(std::initializer_list<
+        std::initializer_list<std::initializer_list<std::initializer_list<T>>>>
+            values)
       : Array(ToInt64Vector({values.size(), values.begin()->size(),
                              values.begin()->begin()->size(),
                              values.begin()->begin()->begin()->size()})) {
@@ -225,46 +173,10 @@ class Array {
     }
   }
// Invokes a callback with the (indices, value_ptr) for each cell in the
// array. If a callback returns a non-OK status, returns that else returns
// Status::OK().
Status EachStatus(
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, &values_[i]);
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
// Invokes a callback with the (indices, value) for each cell in the array.
// If a callback returns a non-OK status, returns that else returns
// Status::OK().
Status EachStatus(
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, values_[i]);
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
   // Returns the value at the cell specified by the indexes. The number of
   // arguments have to match with the number of dimensions for the array.
-  //
-  // The type trait is required to avoid this overload participating too
-  // eagerly; a parameter pack can take zero or more elements, so we must
-  // restrict this to only parameter packs that are all of integral type.
   template <typename... Dims>
-  typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
-                          const T&>::type
-  operator()(Dims... dims) const {
+  const T& operator()(Dims... dims) const {
     // We are using a std::array to avoid having to allocate memory in this
     // function for performance reasons.
     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -274,9 +186,7 @@ class Array {
   // Returns the value at the cell specified by the indexes. The number of
   // arguments have to match with the number of dimensions for the array.
   template <typename... Dims>
-  typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
-                          T&>::type
-  operator()(Dims... dims) {
+  T& operator()(Dims... dims) {
     // We are using a std::array to avoid having to allocate memory in this
     // function for performance reasons.
     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -345,59 +255,6 @@ class Array {
   bool operator!=(const Array<T>& other) const { return !(*this == other); }
// Performs the equivalent of a slice operation on this array.
Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
tensorflow::gtl::ArraySlice<int64> limits) const {
CHECK_EQ(starts.size(), num_dimensions());
CHECK_EQ(limits.size(), num_dimensions());
std::vector<int64> sizes;
std::transform(starts.begin(), starts.end(), limits.begin(),
std::back_inserter(sizes),
[](int64 start, int64 limit) { return limit - start; });
Array<T> result(sizes);
std::vector<int64> index(sizes_.size());
int64 slice_i = 0;
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
if (array_impl::all_inside_range(index, starts, limits)) {
// Even though the bounds of result are different to our bounds, we're
// iterating in the same order. So we can simply write successive linear
// indices instead of recalculating a multi-dimensional index.
result.values_[slice_i++] = values_[i];
}
}
return result;
}
// Performs the equivalent of a DynamicUpdateSlice in-place on this array.
void UpdateSlice(const Array<T>& from,
tensorflow::gtl::ArraySlice<int64> start_indices) {
CHECK_EQ(from.num_dimensions(), num_dimensions());
std::vector<int64> limit_indices;
std::transform(start_indices.begin(), start_indices.end(),
from.dimensions().begin(), std::back_inserter(limit_indices),
std::plus<int64>{});
std::vector<int64> index(sizes_.size());
int64 from_i = 0;
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
// Even though the bounds of from are different to our bounds, we're
// iterating in the same order. So we can simply write successive linear
// indices instead of recalculating a multi-dimensional index.
values_[i] = from.values_[from_i++];
}
}
}
// Performs an in-place reshape, modifying the dimensions but not the
// underlying data.
void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
int64 old_num_elements = num_elements();
sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
CHECK_EQ(num_elements(), old_num_elements);
}
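For context on the removed Slice/UpdateSlice/Reshape helpers above: all of them walk the flat element buffer in row-major order exactly once and touch only the cells whose multi-dimensional index falls inside [starts, limits). A minimal standalone sketch of that indexing scheme for a 2-D row-major buffer, using plain std::vector rather than the XLA Array type (the Slice2D name and int64_t element type are illustrative only):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Copies the rectangle [row_start, row_limit) x [col_start, col_limit) out of
// a row-major rows x cols buffer, visiting the source linearly exactly once.
std::vector<int64_t> Slice2D(const std::vector<int64_t>& src, int64_t rows,
                             int64_t cols, int64_t row_start, int64_t col_start,
                             int64_t row_limit, int64_t col_limit) {
  assert(static_cast<int64_t>(src.size()) == rows * cols);
  std::vector<int64_t> out;
  out.reserve((row_limit - row_start) * (col_limit - col_start));
  for (int64_t i = 0; i < rows * cols; ++i) {
    int64_t r = i / cols;  // recover the multi-dimensional index from the
    int64_t c = i % cols;  // linear position, like next_index() above
    if (r >= row_start && r < row_limit && c >= col_start && c < col_limit) {
      out.push_back(src[i]);  // successive linear writes into the result
    }
  }
  return out;
}

int main() {
  // Mirrors the Slice test further down: a 2x4 array filled with multiples of
  // 1, sliced with starts {1, 0} and limits {2, 2}, yields {4, 5}.
  std::vector<int64_t> arr = {0, 1, 2, 3, 4, 5, 6, 7};
  for (int64_t v : Slice2D(arr, 2, 4, 1, 0, 2, 2)) std::cout << v << " ";
  std::cout << "\n";
}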
// Returns a string representation of the array suitable for debugging. // Returns a string representation of the array suitable for debugging.
string ToString() const { string ToString() const {
std::vector<string> pieces; std::vector<string> pieces;
View File
@ -71,19 +71,6 @@ TEST(ArrayTest, IndexingReadWrite) {
EXPECT_EQ(arr(1, 2), 61); EXPECT_EQ(arr(1, 2), 61);
} }
TEST(ArrayTest, DynamicIndexingReadWrite) {
Array<int> arr({2, 3});
std::vector<int64> index1 = {1, 1};
std::vector<int64> index2 = {1, 2};
EXPECT_EQ(arr(index1), 0);
EXPECT_EQ(arr(index2), 0);
arr(index1) = 51;
arr(index2) = 61;
EXPECT_EQ(arr(1, 1), 51);
EXPECT_EQ(arr(1, 2), 61);
}
TEST(ArrayTest, IndexingReadWriteBool) { TEST(ArrayTest, IndexingReadWriteBool) {
Array<bool> arr{{false, true, false}, {false, true, false}}; Array<bool> arr{{false, true, false}, {false, true, false}};
@ -154,37 +141,5 @@ TEST(ArrayTest, Each) {
EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum);
} }
TEST(ArrayTest, Slice) {
Array<int64> arr({2, 4});
arr.FillWithMultiples(1);
Array<int64> identity_slice = arr.Slice({0, 0}, {2, 4});
EXPECT_EQ(identity_slice.dimensions(), arr.dimensions());
for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end();
it1 != e; ++it1, ++it2) {
EXPECT_EQ(*it1, *it2);
}
Array<int64> sub_slice = arr.Slice({1, 0}, {2, 2});
EXPECT_EQ(sub_slice.dimensions(), (std::vector<int64>{1, 2}));
const string expected = R"([[4, 5]])";
EXPECT_EQ(expected, sub_slice.ToString());
}
TEST(ArrayTest, UpdateSlice) {
Array<int64> arr({3, 4});
arr.FillWithMultiples(1);
Array<int64> sub_arr({2, 2});
sub_arr.FillWithMultiples(3);
arr.UpdateSlice(sub_arr, {1, 1});
const string expected = R"([[0, 1, 2, 3],
[4, 0, 3, 7],
[8, 6, 9, 11]])";
EXPECT_EQ(expected, arr.ToString());
}
} // namespace } // namespace
} // namespace xla } // namespace xla
View File
@ -142,7 +142,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
"TransferToClient request"); "TransferToClient request");
} }
return MakeUnique<Literal>(response.literal()); Literal literal(response.literal());
return MakeUnique<Literal>(literal);
} }
Status Client::ResetDevice() { Status Client::ResetDevice() {
View File
@ -68,7 +68,6 @@ class ShardingBuilder {
const TileAssignment& tile_assignment) { const TileAssignment& tile_assignment) {
OpSharding result; OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER); result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
*result.mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment.dimensions()) { for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim); result.add_tile_assignment_dimensions(dim);
} }
View File
@ -44,7 +44,6 @@ cc_library(
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
View File
@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
@ -49,6 +48,62 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
} // namespace } // namespace
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
MakeFakeLiteral(element_shape));
elements.push_back(std::move(element));
}
return Literal::MakeTupleOwned(std::move(elements));
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
std::minstd_rand0 engine;
switch (shape.element_type()) {
case F32: {
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
TF_CHECK_OK(literal->Populate<float>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S32: {
std::uniform_int_distribution<int32> generator(
std::numeric_limits<int32>::lowest(),
std::numeric_limits<int32>::max());
TF_CHECK_OK(literal->Populate<int32>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case S64: {
std::uniform_int_distribution<int64> generator(
std::numeric_limits<int64>::lowest(),
std::numeric_limits<int64>::max());
TF_CHECK_OK(literal->Populate<int64>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
TF_CHECK_OK(literal->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
}));
break;
}
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
}
return std::move(literal);
}
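The MakeFakeLiteral helper restored above switches on the element type and fills the literal through Populate with values drawn from a uniform distribution seeded by a default-constructed std::minstd_rand0. A standalone sketch of the same pattern for a plain float buffer, with no XLA Literal or ShapeUtil involved (the MakeFakeF32 name is made up):

#include <cstdint>
#include <random>
#include <vector>

// Fills an n-element buffer with uniform [0, 1) floats, mirroring the F32
// branch above; the default-seeded engine makes the data reproducible.
std::vector<float> MakeFakeF32(int64_t n) {
  std::minstd_rand0 engine;
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  std::vector<float> values(static_cast<size_t>(n));
  for (float& v : values) {
    v = dist(engine);  // one draw per element, like the Populate callback
  }
  return values;
}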
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape, std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) { Client* client) {
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
View File
@ -26,6 +26,10 @@ limitations under the License.
namespace xla { namespace xla {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data generation.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
// Generates fake data of the given shape on the device or dies. The fake data // Generates fake data of the given shape on the device or dies. The fake data
// is created by performing a computation on the device rather than transferring // is created by performing a computation on the device rather than transferring
// data from the host to the device. // data from the host to the device.
View File
@ -27,6 +27,16 @@ namespace se = ::perftools::gputools;
namespace xla { namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_platform(
perftools::gputools::Platform* platform) {
platform_ = platform;
return *this;
}
perftools::gputools::Platform* ExecutableBuildOptions::platform() const {
return platform_;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
int device_ordinal) { int device_ordinal) {
device_ordinal_ = device_ordinal; device_ordinal_ = device_ordinal;
@ -46,6 +56,16 @@ const Shape* ExecutableBuildOptions::result_layout() const {
return result_layout_set_ ? &result_layout_ : nullptr; return result_layout_set_ ? &result_layout_ : nullptr;
} }
ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result(
bool has_hybrid_result) {
has_hybrid_result_ = has_hybrid_result;
return *this;
}
bool ExecutableBuildOptions::has_hybrid_result() const {
return has_hybrid_result_;
}
namespace { namespace {
StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal, StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
Backend* backend) { Backend* backend) {
View File
@ -37,6 +37,14 @@ namespace xla {
// LocalClient::Compile. // LocalClient::Compile.
class ExecutableBuildOptions { class ExecutableBuildOptions {
public: public:
// If set, this is the platform to build the computation for. This must match
// the underlying platform of the service. A value of nullptr indicates the
// option has not been set.
//
// TODO(b/28616830): Support multiple platforms.
ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform);
perftools::gputools::Platform* platform() const;
// If set, this is the device to build the computation for. Valid // If set, this is the device to build the computation for. Valid
// device_ordinal values are: 0 to # of devices - 1. These values are // device_ordinal values are: 0 to # of devices - 1. These values are
// identical to the device ordinal values used by StreamExecutor. The built // identical to the device ordinal values used by StreamExecutor. The built
@ -53,10 +61,18 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
const Shape* result_layout() const; const Shape* result_layout() const;
// If set, the executable will be built to output a hybrid
// ShapedBuffer with top-level tuple pointers in host memory and
// result buffers in device memory.
ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result);
bool has_hybrid_result() const;
private: private:
perftools::gputools::Platform* platform_ = nullptr;
int device_ordinal_ = -1; int device_ordinal_ = -1;
Shape result_layout_; Shape result_layout_;
bool result_layout_set_ = false; bool result_layout_set_ = false;
bool has_hybrid_result_ = true;
}; };
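For readers unfamiliar with the idiom used by ExecutableBuildOptions above: every set_* method stores its value and returns *this, so options can be configured in a single chained expression. A toy sketch of that pattern with invented option names (not the real XLA class):

#include <iostream>

// Toy options object; each setter returns *this so calls chain fluently.
class BuildOptions {
 public:
  BuildOptions& set_device_ordinal(int ordinal) {
    device_ordinal_ = ordinal;
    return *this;
  }
  BuildOptions& set_hybrid_result(bool hybrid) {
    hybrid_result_ = hybrid;
    return *this;
  }
  int device_ordinal() const { return device_ordinal_; }
  bool hybrid_result() const { return hybrid_result_; }

 private:
  int device_ordinal_ = -1;  // -1 means "not set", as in the header above
  bool hybrid_result_ = true;
};

int main() {
  BuildOptions options;
  options.set_device_ordinal(0).set_hybrid_result(false);  // chained calls
  std::cout << options.device_ordinal() << " " << options.hybrid_result() << "\n";
}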
class LocalExecutable { class LocalExecutable {

View File
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace {
using tensorflow::int64;
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Converts between little and big endian, assuming elements in the array are 16
// bits long.
void ConvertEndianShort(char* bytes, int64 size) {
CHECK_EQ(size / 2, 0);
for (int64 i = 0; i < size; i += 2) {
std::swap(bytes[i], bytes[i + 1]);
}
}
} // namespace
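ConvertEndianShort above swaps the two bytes of every 16-bit element in place so that F16 (and BF16) payloads serialize identically on little- and big-endian hosts. A standalone sketch of the same swap, asserting that the byte count is even:

#include <cassert>
#include <cstdint>
#include <utility>

// Swaps the two bytes of every 16-bit element in a size-byte buffer.
void SwapBytes16(char* bytes, int64_t size) {
  assert(size % 2 == 0);  // the buffer must hold whole 16-bit elements
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}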
namespace xla { namespace xla {
@ -183,8 +169,6 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<int64>(src_literal, src_base, dest_base, copy_size); return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
case F16: case F16:
return CopyRange<half>(src_literal, src_base, dest_base, copy_size); return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
case BF16:
return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
case F32: case F32:
return CopyRange<float>(src_literal, src_base, dest_base, copy_size); return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64: case F64:
@ -216,8 +200,6 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int64>(0); return *Literal::CreateR0<int64>(0);
case F16: case F16:
return *Literal::CreateR0<half>(static_cast<half>(0.0f)); return *Literal::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32: case F32:
return *Literal::CreateR0<float>(0); return *Literal::CreateR0<float>(0);
case F64: case F64:
@ -303,9 +285,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16: case F16:
return *Literal::CreateR0<half>( return *Literal::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())); static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
return *Literal::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE: case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value"; LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE: case OPAQUE:
@ -342,9 +321,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16: case F16:
return *Literal::CreateR0<half>( return *Literal::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())); static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
return *Literal::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE: case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value"; LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE: case OPAQUE:
@ -452,7 +428,6 @@ std::unique_ptr<Literal> Literal::Transpose(
// The shape with affine layout resulting from that operation will be // The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
// most minor. // most minor.
//
// Essentially, given MinMaj(Di) the position of the Di dimension within the // Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di // minor to major vector, and given T(Di) the index that the original Di
// dimension has within the transposed array, a layout is affine if // dimension has within the transposed array, a layout is affine if
@ -561,9 +536,6 @@ string Literal::GetAsString(
} }
case F16: case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index)); return tensorflow::strings::StrCat(Get<half>(multi_index));
case BF16:
return tensorflow::strings::StrCat(
static_cast<float>(Get<bfloat16>(multi_index)));
default: default:
return tensorflow::strings::StrCat( return tensorflow::strings::StrCat(
"[", PrimitiveType_Name(shape().element_type()), "]"); "[", PrimitiveType_Name(shape().element_type()), "]");
@ -597,17 +569,9 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
} }
string Literal::ToString(bool print_layout) const { string Literal::ToString() const {
std::vector<string> pieces; std::vector<string> pieces;
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
return ShapeUtil::HumanStringWithLayout(shape);
} else {
return ShapeUtil::HumanString(shape);
}
};
auto element_to_string = auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string { [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type(); PrimitiveType element_type = shape().element_type();
@ -621,7 +585,7 @@ string Literal::ToString(bool print_layout) const {
// TODO(b/32894291): refactor this code to reduce code duplication. // TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) { if (ShapeUtil::IsTuple(shape())) {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" (\n"); pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join( pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) { tuple_literals(), ",\n", [](string* out, const Literal& element) {
@ -637,7 +601,7 @@ string Literal::ToString(bool print_layout) const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) { } else if (ShapeUtil::Rank(shape()) == 2) {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { "); pieces.push_back(" { ");
@ -649,7 +613,7 @@ string Literal::ToString(bool print_layout) const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) { } else if (ShapeUtil::Rank(shape()) == 3) {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{"); pieces.push_back(i0 > 0 ? ",\n{" : "{");
@ -664,7 +628,7 @@ string Literal::ToString(bool print_layout) const {
} }
pieces.push_back("\n}"); pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) { } else if (ShapeUtil::Rank(shape()) == 4) {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@ -685,7 +649,7 @@ string Literal::ToString(bool print_layout) const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) { } else if (ShapeUtil::Rank(shape()) == 5) {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n"); pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@ -712,7 +676,7 @@ string Literal::ToString(bool print_layout) const {
} }
pieces.push_back("}"); pieces.push_back("}");
} else { } else {
pieces.push_back(shape_to_string(shape())); pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {...}"); pieces.push_back(" {...}");
} }
@ -771,8 +735,6 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(c64s_.data()); return reinterpret_cast<void*>(c64s_.data());
case F16: case F16:
return reinterpret_cast<void*>(f16s_.data()); return reinterpret_cast<void*>(f16s_.data());
case BF16:
return reinterpret_cast<void*>(bf16s_.data());
default: default:
LOG(FATAL) << "primitive type not supported in literals: " LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type()); << PrimitiveType_Name(shape().element_type());
@ -815,9 +777,6 @@ void Literal::Reserve(int64 num_elements) {
case F16: case F16:
Resize<half>(num_elements, static_cast<half>(0.0f)); Resize<half>(num_elements, static_cast<half>(0.0f));
break; break;
case BF16:
Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
break;
default: default:
LOG(FATAL) << "primitive type not supported in literals: " LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type()); << PrimitiveType_Name(shape().element_type());
@ -857,9 +816,6 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F16: case F16:
actual = f16s().size() / sizeof(half); actual = f16s().size() / sizeof(half);
break; break;
case BF16:
actual = bf16s().size();
break;
default: default:
return tensorflow::errors::Unimplemented( return tensorflow::errors::Unimplemented(
"unhandled element type for literal validation: " + "unhandled element type for literal validation: " +
@ -956,7 +912,6 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F16) CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64) CONVERT_IF_TYPES_MATCH(F64)
CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH #undef CONVERT_IF_TYPES_MATCH
case C64: case C64:
return ConvertToC64<primitive_src_type>(src_literal); return ConvertToC64<primitive_src_type>(src_literal);
@ -986,9 +941,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(F16) CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64) CONVERT_IF_DEST_TYPE_MATCHES(F64)
CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES #undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported. // Other types are not yet supported.
default: default:
return InvalidArgument("Unimplemented: Convert from type %s to type %s", return InvalidArgument("Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(shape().element_type()).c_str(), PrimitiveType_Name(shape().element_type()).c_str(),
@ -1057,8 +1011,6 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index); return EqualElements<double>(*this, other, 0, &multi_index);
case F16: case F16:
return EqualElements<half>(*this, other, 0, &multi_index); return EqualElements<half>(*this, other, 0, &multi_index);
case BF16:
return EqualElements<bfloat16>(*this, other, 0, &multi_index);
case C64: case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index); return EqualElements<complex64>(*this, other, 0, &multi_index);
default: default:
@ -1168,18 +1120,13 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
template <> template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() { tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
// TODO - there is an endianness problem here. Fix it, or wait for uint16
// support in protobuf
auto values = mutable_f16s(); auto values = mutable_f16s();
return tensorflow::gtl::MutableArraySlice<half>(values->data(), return tensorflow::gtl::MutableArraySlice<half>(values->data(),
values->size()); values->size());
} }
template <>
tensorflow::gtl::MutableArraySlice<bfloat16>
Literal::GetMutableArraySlice<bfloat16>() {
auto values = mutable_bf16s();
return {values->data(), values->size()};
}
template <> template <>
tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const { tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
CHECK_EQ(shape().element_type(), PRED); CHECK_EQ(shape().element_type(), PRED);
@ -1250,12 +1197,6 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
f16s().size() / sizeof(half)); f16s().size() / sizeof(half));
} }
template <>
tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
CHECK_EQ(shape().element_type(), BF16);
return {bf16s().data(), bf16s().size()};
}
template <> template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>() tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const { const {
@ -1304,9 +1245,6 @@ bool Literal::IsAll(int8 value) const {
return AllElementsEqualValue<double>(*this, value); return AllElementsEqualValue<double>(*this, value);
case F16: case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value)); return AllElementsEqualValue<half>(*this, static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(*this,
static_cast<bfloat16>(value));
case PRED: case PRED:
if (value == 0) { if (value == 0) {
return AllElementsEqualValue<bool>(*this, false); return AllElementsEqualValue<bool>(*this, false);
@ -1328,9 +1266,6 @@ bool Literal::IsAllFloat(float value) const {
return AllElementsEqualValue<double>(*this, value); return AllElementsEqualValue<double>(*this, value);
case F16: case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value)); return AllElementsEqualValue<half>(*this, static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(*this,
static_cast<bfloat16>(value));
default: default:
return false; return false;
} }
@ -1367,8 +1302,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<complex64>(indices) == complex64(0.0f, 0.0f); return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16: case F16:
return Get<half>(indices) == static_cast<half>(0.0f); return Get<half>(indices) == static_cast<half>(0.0f);
case BF16:
return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED: case PRED:
return Get<bool>(indices) == false; return Get<bool>(indices) == false;
default: default:
@ -1436,12 +1369,6 @@ void Literal::Resize<half>(int64 num_elements, half value) {
mutable_f16s()->resize(num_elements, value); mutable_f16s()->resize(num_elements, value);
} }
template <>
void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_bf16s()->resize(num_elements, value);
}
template <> template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) { void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
@ -1490,19 +1417,6 @@ LiteralProto Literal::ToProto() const {
*proto.mutable_f16s() = *proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()), string(reinterpret_cast<const char*>(f16s_.data()),
f16s_.size() * sizeof(half)); f16s_.size() * sizeof(half));
if (!kLittleEndian) {
ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
proto.f16s().size());
}
break;
case BF16:
*proto.mutable_bf16s() =
string(reinterpret_cast<const char*>(bf16s_.data()),
bf16s_.size() * sizeof(bfloat16));
if (!kLittleEndian) {
ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
proto.bf16s().size());
}
break; break;
case F32: case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s()); CopyToRepeatedField(proto.mutable_f32s(), f32s());
@ -1571,21 +1485,6 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
CHECK_EQ(0, s.size() % sizeof(half)); CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half)); f16s_ = std::vector<half>(s.size() / sizeof(half));
memcpy(f16s_.data(), s.data(), s.size()); memcpy(f16s_.data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
}
break;
}
case BF16: {
const string& s(literal_proto.bf16s());
CHECK_EQ(0, s.size() % sizeof(bfloat16));
bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
memcpy(bf16s_.data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
}
break; break;
} }
case F32: case F32:
View File
@ -163,11 +163,6 @@ class Literal {
const std::vector<complex64>& c64s() const { return c64s_; } const std::vector<complex64>& c64s() const { return c64s_; }
std::vector<complex64>* mutable_c64s() { return &c64s_; } std::vector<complex64>* mutable_c64s() { return &c64s_; }
int bf16s_size() const { return bf16s().size(); }
bfloat16 bf16s(int i) const { return bf16s_[i]; }
const std::vector<bfloat16>& bf16s() const { return bf16s_; }
std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
int tuple_literals_size() const { return tuple_literals().size(); } int tuple_literals_size() const { return tuple_literals().size(); }
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
Literal* add_tuple_literals() { Literal* add_tuple_literals() {
@ -455,7 +450,7 @@ class Literal {
tensorflow::Status ValidateLiteral() const; tensorflow::Status ValidateLiteral() const;
// Returns a string representation of the literal value. // Returns a string representation of the literal value.
string ToString(bool print_layout = false) const; string ToString() const;
// Invokes the "per cell" callback for each element in the provided // Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of // literal with the element's indices and a string representation of
@ -627,7 +622,6 @@ class Literal {
std::vector<uint16> u16s_; std::vector<uint16> u16s_;
std::vector<uint32> u32s_; std::vector<uint32> u32s_;
std::vector<uint64> u64s_; std::vector<uint64> u64s_;
std::vector<bfloat16> bf16s_;
std::vector<half> f16s_; std::vector<half> f16s_;
std::vector<float> f32s_; std::vector<float> f32s_;
std::vector<double> f64s_; std::vector<double> f64s_;
@ -680,9 +674,6 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
template <> template <>
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const; tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
template <>
tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
template <> template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>() tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const; const;
@ -723,9 +714,6 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
template <> template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice(); tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
template <>
tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
template <> template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice(); tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
@ -759,9 +747,6 @@ void Literal::Resize<double>(int64 num_elements, double value);
template <> template <>
void Literal::Resize<half>(int64 num_elements, half value); void Literal::Resize<half>(int64 num_elements, half value);
template <>
void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
template <> template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value); void Literal::Resize<complex64>(int64 num_elements, complex64 value);
@ -1005,14 +990,6 @@ inline half Literal::Get<half>(
return GetArraySlice<half>()[linear_index]; return GetArraySlice<half>()[linear_index];
} }
template <>
inline bfloat16 Literal::Get<bfloat16>(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(shape().element_type() == BF16);
int64 linear_index = LinearIndex(multi_index);
return GetArraySlice<bfloat16>()[linear_index];
}
template <typename NativeT> template <typename NativeT>
void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) { NativeT value) {
View File
@ -110,18 +110,6 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f}); auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
ASSERT_EQ("0.5", bf16_lit->ToString());
// 3.14 will be rounded to 3.125 in bfloat16 format (Round to nearest even).
auto bf16_lit_truncated =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
auto bf16_lit_truncated2 =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
ASSERT_EQ("9", bf16_lit_truncated2->ToString());
} }
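The expected strings in the bfloat16 asserts above follow from keeping only the top 16 bits of the float32 representation with round-to-nearest-even: 3.14f becomes 3.140625 and 9.001f collapses back to 9. A standalone sketch of that conversion using raw bit manipulation (not the TensorFlow bfloat16 type):

#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>

// Rounds a float to bfloat16 precision (the upper 16 bits of the float32 bit
// pattern, rounded to nearest even) and widens it back to float.
float RoundTripBfloat16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1u;  // ties go to the even 7-bit mantissa
  uint16_t bf16 = static_cast<uint16_t>((bits + 0x7FFFu + lsb) >> 16);
  uint32_t widened = static_cast<uint32_t>(bf16) << 16;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  std::cout << std::setprecision(7) << RoundTripBfloat16(3.14f) << "\n"   // 3.140625
            << RoundTripBfloat16(9.001f) << "\n";                         // 9
}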
TEST_F(LiteralUtilTest, LiteralVectorToString) { TEST_F(LiteralUtilTest, LiteralVectorToString) {
@ -409,18 +397,6 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
complex64 c8_9 = {8, 9}; complex64 c8_9 = {8, 9};
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
@ -715,30 +691,6 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
EXPECT_EQ(output, *expected); EXPECT_EQ(output, *expected);
} }
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
Literal output;
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h, {});
auto expected = Literal::CreateR0<bfloat16>(h);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
Literal output;
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h, {3});
auto expected = Literal::CreateR1<bfloat16>({h, h, h});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
Literal output;
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h, {2, 2});
auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output; Literal output;
output.PopulateWithValue<float>(2.5f, {}); output.PopulateWithValue<float>(2.5f, {});
@ -1023,14 +975,6 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)}, {{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}}, {half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_); }}, layout_r4_dim0major_);
auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
{{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
{{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
{bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
{{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
}}, layout_r4_dim0major_);
auto f32 = Literal::CreateR4WithLayout<float>({{ auto f32 = Literal::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@ -1064,12 +1008,6 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s8->Convert(PRED).ConsumeValueOrDie(); conv = s8->Convert(PRED).ConsumeValueOrDie();
EXPECT_EQ(*conv, *pred); EXPECT_EQ(*conv, *pred);
conv = bf16->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *s32);
conv = bf16->Convert(F32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *f32);
conv = pred->Convert(S32).ConsumeValueOrDie(); conv = pred->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *int32_pred); EXPECT_EQ(*conv, *int32_pred);
View File
@ -78,11 +78,6 @@ PrimitiveType NativeToPrimitiveType<double>() {
return F64; return F64;
} }
template <>
PrimitiveType NativeToPrimitiveType<bfloat16>() {
return BF16;
}
template <> template <>
PrimitiveType NativeToPrimitiveType<half>() { PrimitiveType NativeToPrimitiveType<half>() {
return F16; return F16;
@ -94,7 +89,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
} }
bool IsFloatingPointType(PrimitiveType type) { bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64 || type == BF16; return type == F16 || type == F32 || type == F64;
} }
bool IsComplexType(PrimitiveType type) { return type == C64; } bool IsComplexType(PrimitiveType type) { return type == C64; }
@ -123,7 +118,6 @@ int BitWidth(PrimitiveType type) {
case S16: case S16:
case U16: case U16:
case F16: case F16:
case BF16:
return 16; return 16;
case U32: case U32:
View File
@ -77,8 +77,6 @@ template <>
PrimitiveType NativeToPrimitiveType<double>(); PrimitiveType NativeToPrimitiveType<double>();
template <> template <>
PrimitiveType NativeToPrimitiveType<half>(); PrimitiveType NativeToPrimitiveType<half>();
template <>
PrimitiveType NativeToPrimitiveType<bfloat16>();
// Complex // Complex
template <> template <>
@ -169,11 +167,6 @@ struct PrimitiveTypeToNative<F16> {
using type = half; using type = half;
}; };
template <>
struct PrimitiveTypeToNative<BF16> {
using type = bfloat16;
};
// Complex // Complex
template <> template <>
struct PrimitiveTypeToNative<C64> { struct PrimitiveTypeToNative<C64> {
View File
@ -90,8 +90,6 @@ cc_library(
":shape_inference", ":shape_inference",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
@ -1780,6 +1778,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
@ -1850,6 +1849,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
View File
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/backend.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <utility> #include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"
View File
@ -497,19 +497,19 @@ Status GatherComputationsByAllocationType(
std::vector<const HloComputation*>* global_computations) { std::vector<const HloComputation*>* global_computations) {
// Create a worklist of computations paired with whether the allocation must // Create a worklist of computations paired with whether the allocation must
// be thread-local. // be thread-local.
std::deque<std::pair<const HloComputation*, bool>> worklist; std::deque<std::pair<HloComputation*, bool>> worklist;
worklist.push_back(std::make_pair(module->entry_computation(), worklist.push_back(std::make_pair(module->entry_computation(),
/*is_thread_local*/ false)); /*is_thread_local*/ false));
// Sets for quickly checking membership. Computations are returned in vectors // Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration. // for stable iteration.
FlatSet<const HloComputation*> thread_local_set; FlatSet<HloComputation*> thread_local_set;
FlatSet<const HloComputation*> global_set; FlatSet<HloComputation*> global_set;
while (!worklist.empty()) { while (!worklist.empty()) {
auto worklist_front = worklist.front(); auto worklist_front = worklist.front();
worklist.pop_front(); worklist.pop_front();
const HloComputation* computation = worklist_front.first; HloComputation* computation = worklist_front.first;
bool is_thread_local = worklist_front.second; bool is_thread_local = worklist_front.second;
bool in_thread_local_set = thread_local_set.count(computation) > 0; bool in_thread_local_set = thread_local_set.count(computation) > 0;
bool in_global_set = global_set.count(computation) > 0; bool in_global_set = global_set.count(computation) > 0;
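The gathering pass above is a plain worklist traversal: pop a (computation, is_thread_local) pair, skip it if it has already been recorded with that allocation kind, otherwise record it and enqueue the computations it calls. A stripped-down sketch of that control flow over a plain string adjacency list (std::set standing in for FlatSet, and the thread-local inheritance rule simplified):

#include <deque>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Partitions nodes reachable from `entry` into thread-local and global sets,
// following the same worklist/membership-set shape as the pass above.
void Partition(const std::map<std::string, std::vector<std::string>>& callees,
               const std::string& entry,
               std::set<std::string>* thread_local_set,
               std::set<std::string>* global_set) {
  std::deque<std::pair<std::string, bool>> worklist;
  worklist.push_back({entry, /*is_thread_local=*/false});
  while (!worklist.empty()) {
    auto [node, is_thread_local] = worklist.front();
    worklist.pop_front();
    std::set<std::string>* target = is_thread_local ? thread_local_set : global_set;
    if (!target->insert(node).second) continue;  // already classified this way
    auto it = callees.find(node);
    if (it == callees.end()) continue;
    for (const std::string& callee : it->second) {
      // Called computations inherit the caller's allocation kind in this sketch.
      worklist.push_back({callee, is_thread_local});
    }
  }
}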
@ -653,7 +653,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
} }
if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
const HloComputation* entry_computation = HloComputation* entry_computation =
assignment->module_->entry_computation(); assignment->module_->entry_computation();
for (auto param : entry_computation->parameter_instructions()) { for (auto param : entry_computation->parameter_instructions()) {
for (auto& param_buffer : for (auto& param_buffer :
@ -819,6 +819,17 @@ Status BufferAssigner::AssignBuffersForComputation(
continue; continue;
} }
if (instruction->opcode() == HloOpcode::kRecv) {
// Make sure that recv operations get a new unique allocation so that they
// don't share their buffer with any other operations.
BufferAllocation* allocation = assignment->NewAllocation(
*buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
allocation_indices.push_back(allocation->index());
VLOG(3) << "New allocation #" << allocation->index()
<< " for recv: " << *buffer;
continue;
}
if (ShapeUtil::IsTuple(buffer->shape())) { if (ShapeUtil::IsTuple(buffer->shape())) {
// TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
// assumes longer buffer liveness than indicated by the analysis. // assumes longer buffer liveness than indicated by the analysis.
View File
@ -280,7 +280,6 @@ cc_library(
srcs = ["dot_op_emitter.cc"], srcs = ["dot_op_emitter.cc"],
hdrs = ["dot_op_emitter.h"], hdrs = ["dot_op_emitter.h"],
deps = [ deps = [
":cpu_options",
":cpu_runtime", ":cpu_runtime",
":ir_emission_utils", ":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -291,10 +290,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:vector_support_library",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@llvm//:core", "@llvm//:core",
], ],
@ -720,7 +717,6 @@ cc_library(
hdrs = ["cpu_options.h"], hdrs = ["cpu_options.h"],
deps = [ deps = [
"//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/core:lib",
], ],
) )
View File
@ -15,14 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/core/lib/strings/numbers.h"
namespace { namespace {
const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaParallelCpuOption = "xla_cpu_parallel";
const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
} // namespace } // namespace
@ -48,19 +45,6 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
} }
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
int64 tiling_factor;
if (it != extra_options_map.end() &&
tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
return tiling_factor;
}
return tensorflow::gtl::nullopt;
}
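LlvmIrGemvTilingFactor above just looks a key up in the xla_backend_extra_options string map and parses it as an int64, returning an empty optional when the key is absent or malformed. A standalone sketch of that lookup using std::optional and strtoll instead of the TensorFlow helpers (the GetInt64Option name is invented):

#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <optional>
#include <string>

// Returns the parsed integer value of `key`, or std::nullopt if the key is
// missing or its value is not a valid base-10 integer.
std::optional<int64_t> GetInt64Option(
    const std::map<std::string, std::string>& extra_options,
    const std::string& key) {
  auto it = extra_options.find(key);
  if (it == extra_options.end()) return std::nullopt;
  char* end = nullptr;
  errno = 0;
  long long value = std::strtoll(it->second.c_str(), &end, 10);
  if (errno != 0 || end == it->second.c_str() || *end != '\0') return std::nullopt;
  return static_cast<int64_t>(value);
}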
} // namespace options } // namespace options
} // namespace cpu } // namespace cpu
} // namespace xla } // namespace xla
View File
@ -27,8 +27,6 @@ namespace options {
bool CpuParallelBackendRequested(const HloModuleConfig& config); bool CpuParallelBackendRequested(const HloModuleConfig& config);
bool OptimizeForSizeRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config);
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
const HloModuleConfig& config);
} // namespace options } // namespace options
} // namespace cpu } // namespace cpu
View File
@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
View File
@ -25,9 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
@ -40,450 +38,6 @@ using llvm_ir::SetToFirstInsertPoint;
namespace cpu { namespace cpu {
namespace {
// Loads a tile of values from a 2D tensor.
class TileLoader {
public:
// Constructs a TileLoader that will load a tile consisting of
// `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
// `major_dim_offset` in the major dimension. The tile size along the minor
// dimension is the vector size, and that is implicitly determined by `vsl`.
TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
llvm::Value* matrix, int64 matrix_size_along_minor_dim,
llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
: vsl_(vsl) {
pointers_.reserve(tile_size_along_major_dim);
for (int64 i = 0; i < tile_size_along_major_dim; i++) {
llvm::Value* total_offset = ir_builder->CreateMul(
ir_builder->getInt64(matrix_size_along_minor_dim),
ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
}
}
// Load a tile consisting of `tile_size_along_major_dim_` vectors starting at
// `major_dim_offset_` in the major dimension and `minor_dim_offset` in the
// minor dimension.
std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
std::vector<llvm::Value*> result;
result.reserve(pointers_.size());
for (const auto& pointer : pointers_) {
result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
}
return result;
}
private:
VectorSupportLibrary* vsl_;
std::vector<llvm::Value*> pointers_;
};
// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
// +----------------------+---+
// | | |
// | | |
// | A | B |
// | | |
// | | |
// | | |
// +----------------------+---+
// | C | D |
// +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
// +---+---+---+---+ +--+--+--+--+
// |M00|M10|M20|M30| |V0|V1|V2|V3|
// +---+---+---+---+ +--+--+--+--+
// |M01|M11|M21|M31| and |V0|V1|V2|V3|
// +---+---+---+---+ +--+--+--+--+
// |M02|M12|M22|M32| |V0|V1|V2|V3|
// +---+---+---+---+ +--+--+--+--+
// |M03|M13|M23|M33| |V0|V1|V2|V3|
// +---+---+---+---+ +--+--+--+--+
//
// (Legend: rows are horizontal and columns are vertical; and each column is one
// llvm::Value of a vector type)
//
// where:
//
// a. The left tile is from the column major left matrix.
// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
// vector loaded from the RHS vector.
//
// As we iterate through the column dimension, we compute the change to the
// result vector by an elementwise multiplication between the two tiles above
// followed by a reduction along the major dimension:
//
// +-----------------------------------+
// | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
// +-----------------------------------+
// | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
// Result[R:R+4] += +-----------------------------------+
// | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
// +-----------------------------------+
// | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
// +-----------------------------------+
//
// Where R is the starting row for the tile.
//
// We have an inner epilogue loop to deal with the "C" submatrix and an outer
// epilogue loop to deal with the B,D submatrix.
//
// TODO(sanjoy): We should investigate if using gather loads and scatter stores
// can be used here to have the same inner loop for both column-major and row-major
// matrix-vector products.
class ColumnMajorMatrixVectorProductEmitter {
public:
ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
int64 tile_rows, int64 tile_cols,
int64 m, int64 k, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
: scalar_type_(scalar_type),
tile_rows_(tile_rows),
tile_cols_(tile_cols),
m_(m),
k_(k),
lhs_(lhs),
rhs_(rhs),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
}
void Emit();
private:
void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
bool is_first_column);
TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
/*matrix_size_along_minor_dim=*/m_,
/*major_dim_offset=*/column_start,
/*tile_size_along_major_dim=*/column_count);
}
// Load a tile of values from the RHS. For the RHS a "tile" is a contiguous
// sequence of `count` values, each one broadcasted to the vector width.
std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
std::vector<llvm::Value*> result;
result.reserve(count);
for (int64 i = 0; i < count; i++) {
result.push_back(vsl_.LoadBroadcast(base_pointer, i));
}
return result;
}
void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
const std::vector<llvm::Value*>& rhs_tile,
int64 columns, bool is_first_column);
void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
bool is_first_tiled_column);
PrimitiveType scalar_type_;
int64 tile_rows_;
int64 tile_cols_;
int64 m_;
int64 k_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* result_;
llvm::IRBuilder<>* ir_builder_;
KernelSupportLibrary ksl_;
VectorSupportLibrary vsl_;
};
void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
llvm::Value* column, int64 column_count, bool is_first_column) {
TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
/*column_count=*/column_count);
std::vector<llvm::Value*> rhs_tile =
LoadRhsTile(column, /*count=*/column_count);
EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
/*columns=*/column_count, is_first_column);
EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}
void ColumnMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
int64 column_remainder = k_ % tile_cols_;
int64 column_limit = k_ - column_remainder;
ksl_.For("dot.outer.tiled",
/*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
[&](llvm::Value* column, bool is_first_column) {
EmitOuterLoopBody(column, tile_cols_, is_first_column);
});
if (column_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
column_limit == 0);
}
}
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
int64 columns, bool is_first_column) {
int64 row_limit = m_ - (m_ % tile_rows_);
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
/*step=*/tile_rows_, [&](llvm::Value* row) {
std::vector<llvm::Value*> lhs_tile =
lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
llvm::Value* accumulator = is_first_column
? vsl_.GetZeroVector()
: vsl_.LoadVector(result_, row);
for (int i = 0; i < columns; i++) {
accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
}
vsl_.StoreVector(accumulator, result_, row);
});
}
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
int64 row_start = m_ - (m_ % tile_rows_);
if (row_start == m_) {
return;
}
llvm::Value* columns_llvm = ir_builder_->getInt64(columns);
// for (col = current_tile_col; col < (columns + current_tile_col); col++)
//   for (row = row_start; row < m_; row++) {
// result[row] += lhs[row, col] * rhs[col]
// // Also take into account that if col is 0 then result[row] is not
// // initialized.
// }
ksl_.For(
"dot.inner.epilg.outer", /*start=*/current_tile_col,
/*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
/*step=*/1, /*peel_first_iteration=*/false,
[&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
llvm::Value* total_offset =
ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.For(
"dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
is_first_scalar_col,
ir_builder_->getInt1(is_first_tiled_column));
ksl_.If(
setting_result_first_time,
[&]() { vsl_.StoreScalar(product, result_, scalar_row); },
[&]() {
vsl_.StoreScalar(
vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
result_, scalar_row);
});
});
});
}
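// Editorial sketch, not part of this change: a plain scalar C++ rendering of
// the column-major loop structure described in the comment above
// ColumnMajorMatrixVectorProductEmitter, for readers who want the algorithm
// without the LLVM IR plumbing. It assumes m > 0, k > 0, a float element
// type, a dense column-major `lhs` of shape [m, k] addressed as
// lhs[row + col * m], and the XLA `int64` typedef already used in this file.
// The vectorized row tile is shown as a plain inner loop over `tile_rows`;
// the helper is for exposition only and is not called anywhere.
static void ReferenceColumnMajorGemv(const float* lhs, const float* rhs,
                                     float* result, int64 m, int64 k,
                                     int64 tile_rows, int64 tile_cols) {
  const int64 row_limit = m - (m % tile_rows);
  const int64 col_limit = k - (k % tile_cols);
  // Outer loop over column tiles ("dot.outer.tiled" plus the outer epilogue).
  for (int64 col = 0; col < k; col += tile_cols) {
    const int64 cols = (col < col_limit) ? tile_cols : (k - col_limit);
    const bool is_first_column = (col == 0);
    // Tiled inner loop over whole row tiles (the "A"/"B" region). The real
    // emitter keeps each row tile in a vector register of width tile_rows.
    for (int64 row = 0; row < row_limit; row += tile_rows) {
      for (int64 r = 0; r < tile_rows; ++r) {
        float acc = is_first_column ? 0.0f : result[row + r];
        for (int64 c = 0; c < cols; ++c) {
          acc += lhs[(row + r) + (col + c) * m] * rhs[col + c];
        }
        result[row + r] = acc;
      }
    }
    // Scalar epilogue for the leftover rows (the "C"/"D" region); result[row]
    // is only written, never read, on the very first column.
    for (int64 c = 0; c < cols; ++c) {
      for (int64 row = row_limit; row < m; ++row) {
        const float product = lhs[row + (col + c) * m] * rhs[col + c];
        if (is_first_column && c == 0) {
          result[row] = product;
        } else {
          result[row] += product;
        }
      }
    }
  }
}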
// Computes a dot product of a "[M,K]{1,0} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
//
// We logically separate the LHS matrix into four segments:
//
// +----------------------+---+
// | | |
// | | |
// | A | B |
// | | |
// | | |
// | | |
// +----------------------+---+
// | C | D |
// +----------------------+---+
//
// where A is the largest submatrix of the LHS that can be evenly divided into
// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
//
// +---+---+---+---+
// |M00|M10|M20|M30|
// +---+---+---+---+ +--+--+--+--+
// |M01|M11|M21|M31| and |V0|V1|V2|V3|
// +---+---+---+---+ +--+--+--+--+
// |M02|M12|M22|M32|
// +---+---+---+---+
// |M03|M13|M23|M33|
// +---+---+---+---+
//
// (Legend: rows are horizontal and columns are vertical, and each row is one
// llvm::Value of a vector type)
//
// where:
//
// a. The left tile is loaded from the row-major left matrix.
// b. The right vector is loaded from the RHS vector.
//
// We keep 4 vector accumulators accumulating the following four vector
// expressions as we iterate over the row dimension:
//
// +------+------+------+------+
// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4)
// +------+------+------+------+
//
// In the end we do a horizontal reduction over these 4 vector accumulators to
// get 4 values in the result vector.
//
// We have an inner epilogue loop to deal with the "B" submatrix and an outer
// epilogue loop to deal with the C,D submatrices.
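//
// (For reference, an illustrative scalar sketch of this loop structure is
// included after the emitter's member definitions below.)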
class RowMajorMatrixVectorProductEmitter {
public:
RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
int64 tile_cols, int64 m, int64 k,
llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
: scalar_type_(scalar_type),
tile_rows_(tile_rows),
tile_cols_(tile_cols),
m_(m),
k_(k),
lhs_(lhs),
rhs_(rhs),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
}
void Emit();
private:
TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
/*matrix_size_along_minor_dim=*/k_,
/*major_dim_offset=*/row_start,
/*tile_size_along_major_dim=*/row_count);
}
void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
std::vector<VectorVariable>* vector_accumulators);
void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators);
PrimitiveType scalar_type_;
int64 tile_rows_;
int64 tile_cols_;
int64 m_;
int64 k_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* result_;
llvm::IRBuilder<>* ir_builder_;
KernelSupportLibrary ksl_;
VectorSupportLibrary vsl_;
};
void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
int64 row_count) {
TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
/*row_count=*/row_count);
std::vector<VectorVariable> vector_accumulators;
std::vector<ScalarVariable> scalar_accumulators;
for (int i = 0; i < row_count; i++) {
vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
}
EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
&vector_accumulators);
EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
&scalar_accumulators);
for (int i = 0; i < row_count; i++) {
llvm::Value* result_value =
vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()),
scalar_accumulators[i].Get());
llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
vsl_.StoreScalar(result_value, result_, offset);
}
}
void RowMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
int64 row_remainder = m_ % tile_rows_;
int64 row_limit = m_ - row_remainder;
ksl_.For("dot.outer.tiled",
/*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
[&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
if (row_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
}
}
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
TileLoader* lhs_tile_loader, int64 rows,
std::vector<VectorVariable>* vector_accumulators) {
int64 column_limit = k_ - (k_ % tile_cols_);
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
/*step=*/tile_cols_, [&](llvm::Value* col) {
std::vector<llvm::Value*> lhs_tile =
lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
for (int i = 0; i < rows; i++) {
llvm::Value* old_sum = (*vector_accumulators)[i].Get();
(*vector_accumulators)[i].Set(
vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
}
});
}
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators) {
int64 column_start = k_ - (k_ % tile_cols_);
if (column_start == k_) {
return;
}
for (int r = 0; r < rows; r++) {
llvm::Value* total_offset = ir_builder_->CreateMul(
ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
ir_builder_->getInt64(k_));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
/*step=*/1, [&](llvm::Value* scalar_col) {
llvm::Value* product =
vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
vsl_.LoadScalar(rhs_, scalar_col));
llvm::Value* old_value = (*scalar_accumulators)[r].Get();
(*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
});
}
}
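// Editorial sketch, not part of this change: a plain scalar C++ rendering of
// the row-major loop structure described in the comment above
// RowMajorMatrixVectorProductEmitter. It assumes a float element type, a
// dense row-major `lhs` of shape [m, k] addressed as lhs[row * k + col], and
// the XLA `int64` typedef; the per-row vector accumulator and its horizontal
// add-reduce are collapsed into a single scalar accumulator. For exposition
// only; not called anywhere.
static void ReferenceRowMajorGemv(const float* lhs, const float* rhs,
                                  float* result, int64 m, int64 k,
                                  int64 tile_rows, int64 tile_cols) {
  const int64 row_limit = m - (m % tile_rows);
  const int64 col_limit = k - (k % tile_cols);
  // Outer loop over row tiles ("dot.outer.tiled" plus the outer epilogue).
  for (int64 row = 0; row < m; row += tile_rows) {
    const int64 rows = (row < row_limit) ? tile_rows : (m - row_limit);
    for (int64 r = 0; r < rows; ++r) {
      float acc = 0.0f;
      // Tiled inner loop over whole column tiles (the "A"/"C" region). The
      // real emitter keeps this partial sum in a vector of width tile_cols.
      for (int64 col = 0; col < col_limit; col += tile_cols) {
        for (int64 c = 0; c < tile_cols; ++c) {
          acc += lhs[(row + r) * k + col + c] * rhs[col + c];
        }
      }
      // Scalar epilogue over the leftover columns (the "B"/"D" region).
      for (int64 col = col_limit; col < k; ++col) {
        acc += lhs[(row + r) * k + col] * rhs[col];
      }
      result[row + r] = acc;
    }
  }
}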
} // namespace
DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
bool transpose_rhs, bool transpose_rhs,
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& target_array,
@ -518,93 +72,6 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
if (dot_.shape().dimensions_size() != 2 ||
ProfitableToImplementDotInUntiledLlvmIr(dot_) ==
DotInLlvmIrProfitable::kYes) {
return false;
}
if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
!primitive_util::IsIntegralType(dot_.shape().element_type())) {
return false;
}
MatMultDims mat_mult_dims = GetMatMultDims();
bool is_column_major_matrix_vector = false;
bool is_row_major_matrix_vector = false;
int64 m, k;
bool swap_operands;
if (mat_mult_dims.m == 1) {
bool rhs_effectively_row_major =
transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
if (rhs_effectively_row_major) {
k = mat_mult_dims.k;
m = mat_mult_dims.n;
is_column_major_matrix_vector = true;
swap_operands = true;
} else {
k = mat_mult_dims.k;
m = mat_mult_dims.n;
is_row_major_matrix_vector = true;
swap_operands = true;
}
}
if (mat_mult_dims.n == 1) {
bool lhs_effectively_column_major =
transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
if (lhs_effectively_column_major) {
m = mat_mult_dims.m;
k = mat_mult_dims.k;
is_column_major_matrix_vector = true;
swap_operands = false;
} else {
m = mat_mult_dims.m;
k = mat_mult_dims.k;
is_row_major_matrix_vector = true;
swap_operands = false;
}
}
if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
return false;
}
int64 tiling_factor = GetGemvTilingFactor();
CHECK_GT(tiling_factor, 0);
if (is_column_major_matrix_vector) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
ColumnMajorMatrixVectorProductEmitter emitter(
dot_.shape().element_type(), /*tile_rows=*/8,
/*tile_cols=*/tiling_factor, m, k,
swap_operands ? rhs_array_.GetBasePointer()
: lhs_array_.GetBasePointer(),
swap_operands ? lhs_array_.GetBasePointer()
: rhs_array_.GetBasePointer(),
target_array_.GetBasePointer(), ir_builder_);
emitter.Emit();
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
RowMajorMatrixVectorProductEmitter emitter(
dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
/*tile_cols=*/8, m, k,
swap_operands ? rhs_array_.GetBasePointer()
: lhs_array_.GetBasePointer(),
swap_operands ? lhs_array_.GetBasePointer()
: rhs_array_.GetBasePointer(),
target_array_.GetBasePointer(), ir_builder_);
emitter.Emit();
}
return true;
}
tensorflow::Status DotOpEmitter::Emit() { tensorflow::Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left // The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand. // hand side operand and dimension 1 of the right hand side operand.
@ -638,10 +105,6 @@ tensorflow::Status DotOpEmitter::Emit() {
return EmitScalarDot(); return EmitScalarDot();
} }
if (EmitLlvmIrDotIfProfitable()) {
return Status::OK();
}
if (PotentiallyImplementedAsEigenDot(dot_)) { if (PotentiallyImplementedAsEigenDot(dot_)) {
return EmitCallToRuntime(); return EmitCallToRuntime();
} }
@ -877,17 +340,22 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
// //
// Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
MatMultDims mat_mult_dims = GetMatMultDims(); const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major); CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout()));
int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0);
int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1);
int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1);
const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* lhs = &lhs_array_;
const llvm_ir::IrArray* rhs = &rhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_;
bool transpose_lhs = transpose_lhs_; bool transpose_lhs = transpose_lhs_;
bool transpose_rhs = transpose_rhs_; bool transpose_rhs = transpose_rhs_;
if (!mat_mult_dims.lhs_column_major) { bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0;
std::swap(mat_mult_dims.m, mat_mult_dims.n); if (!is_column_major) {
std::swap(m, n);
std::swap(lhs, rhs); std::swap(lhs, rhs);
std::swap(transpose_lhs, transpose_rhs); std::swap(transpose_lhs, transpose_rhs);
} }
@ -899,27 +367,12 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
float_ptr_type), float_ptr_type),
ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
ir_builder_->getInt64(mat_mult_dims.m), ir_builder_->getInt64(m), ir_builder_->getInt64(n),
ir_builder_->getInt64(mat_mult_dims.n), ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs),
ir_builder_->getInt64(mat_mult_dims.k),
ir_builder_->getInt32(transpose_lhs),
ir_builder_->getInt32(transpose_rhs)}); ir_builder_->getInt32(transpose_rhs)});
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
CHECK_EQ(dot_.shape().dimensions_size(), 2);
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
lhs_shape.layout().minor_to_major(0) == 0,
rhs_shape.layout().minor_to_major(0) == 0};
}
llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
int64 reduction_dimension, tensorflow::StringPiece name_suffix) { int64 reduction_dimension, tensorflow::StringPiece name_suffix) {

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@ -60,10 +59,6 @@ class DotOpEmitter {
// LHS and RHS) and store the results in the target. // LHS and RHS) and store the results in the target.
tensorflow::Status EmitScalarDot(); tensorflow::Status EmitScalarDot();
// Emit an LLVM IR implementation of the dot operation if we can. Returns
// true if an LLVM IR implementation was emitted.
bool EmitLlvmIrDotIfProfitable();
// Emits a call to the CPU runtime to perform the matrix multiply. // Emits a call to the CPU runtime to perform the matrix multiply.
tensorflow::Status EmitCallToRuntime(); tensorflow::Status EmitCallToRuntime();
@ -82,38 +77,6 @@ class DotOpEmitter {
// no padding, and a rank of two. // no padding, and a rank of two.
bool ShapesAreLegalForRuntimeDot() const; bool ShapesAreLegalForRuntimeDot() const;
// Represents the dimensions of a matrix-matrix multiply operation.
struct MatMultDims {
// The number of rows in the LHS.
int64 m;
// The number of columns in the LHS, which must also be equal to the
// number of rows in the RHS.
int64 k;
// The number of columns in the RHS.
int64 n;
// True if the LHS matrix is column major.
bool lhs_column_major;
// True if the RHS matrix is column major.
bool rhs_column_major;
};
// Get the MatMultDims instance for the dot product this DotOpEmitter
// represents. Precondition: the dot is of rank 2 (and thus its operands are
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.
int64 GetGemvTilingFactor() const {
const int64 kDefaultTilingFactor = 8;
return options::LlvmIrGemvTilingFactor(hlo_module_config_)
.value_or(kDefaultTilingFactor);
}
const HloInstruction& dot_; const HloInstruction& dot_;
const bool transpose_lhs_; const bool transpose_lhs_;
const bool transpose_rhs_; const bool transpose_rhs_;

View File

@ -105,9 +105,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false; return false;
} }
if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) {
DotInLlvmIrProfitable::kYes ||
ProfitableToImplementDotInTiledLlvmIr(hlo)) {
return false; return false;
} }
@ -138,7 +136,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false; return false;
} }
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
const HloInstruction& dot) { const HloInstruction& dot) {
if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
const Shape& result_shape = dot.shape(); const Shape& result_shape = dot.shape();
@ -180,16 +178,5 @@ DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
return DotInLlvmIrProfitable::kNo; return DotInLlvmIrProfitable::kNo;
} }
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
// Any matrix-vector product of floating-point or integral type, or a
// transpose-dot fusion of the same, can be lowered to a tiled LLVM IR
// implementation.
const Shape& shape = dot.shape();
return shape.dimensions_size() == 2 &&
(shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
(primitive_util::IsFloatingPointType(shape.element_type()) ||
primitive_util::IsIntegralType(shape.element_type()));
}
} // namespace cpu } // namespace cpu
} // namespace xla } // namespace xla

View File

@ -29,21 +29,16 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
// Returns a value to indicate if (and under what conditions) will lowering // Returns a value to indicate if (and under what conditions) will lowering
// |dot| as an untiled LLVM IR dot operation be profitable over calling into // |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen.
// Eigen or emitting a tiled LLVM IR implementation. Possible return values // Possible return values are:
// are:
// //
// * DotInLlvmIrProfitable::kYes - always profitable. // * DotInLlvmIrProfitable::kYes - always profitable.
// * DotInLlvmIrProfitable::kNo - never profitable. // * DotInLlvmIrProfitable::kNo - never profitable.
// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make // * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
// the Rhs layout column major. // the Rhs layout column major.
DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
const HloInstruction& dot); const HloInstruction& dot);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
// for |dot|.
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
} // namespace cpu } // namespace cpu
} // namespace xla } // namespace xla

View File

@ -1983,11 +1983,6 @@ Status IrEmitter::HandleSend(HloInstruction* send) {
return Unimplemented("Send is not implemented on CPU. See b/33942983."); return Unimplemented("Send is not implemented on CPU. See b/33942983.");
} }
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandleSlice(HloInstruction* slice) { Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString(); VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0); auto operand = slice->operand(0);
@ -2153,11 +2148,6 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) {
return Unimplemented("Recv is not implemented on CPU. See b/33942983."); return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
} }
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
}
Status IrEmitter::HandlePad(HloInstruction* pad) { Status IrEmitter::HandlePad(HloInstruction* pad) {
// CPU backend does not properly handle negative padding but this is ok // CPU backend does not properly handle negative padding but this is ok
// because negative padding should be removed by the algebraic simplifier. // because negative padding should be removed by the algebraic simplifier.

View File

@ -171,13 +171,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleReduceWindow(HloInstruction* reduce_window) override;
Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
Status HandleSend(HloInstruction* send) override; Status HandleSend(HloInstruction* send) override;
Status HandleSendDone(HloInstruction* send_done) override;
Status HandleSlice(HloInstruction* slice) override; Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice( Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override; HloInstruction* dynamic_update_slice) override;
Status HandleRecv(HloInstruction* recv) override; Status HandleRecv(HloInstruction* recv) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override; Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override; Status HandleTuple(HloInstruction* tuple) override;
Status HandleMap(HloInstruction* map) override; Status HandleMap(HloInstruction* map) override;

View File

@ -51,7 +51,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
tensorflow::gtl::FlatMap<const HloInstruction*, bool> tensorflow::gtl::FlatMap<const HloInstruction*, bool>
should_make_rhs_col_major_cache; should_make_rhs_col_major_cache;
auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { auto should_make_rhs_col_major = [&](const HloInstruction& instruction) {
if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != if (ProfitableToImplementDotInLlvmIr(instruction) !=
DotInLlvmIrProfitable::kWithColumnMajorRhs) { DotInLlvmIrProfitable::kWithColumnMajorRhs) {
return false; return false;
} }
@ -68,7 +68,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
bool result = std::all_of( bool result = std::all_of(
rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) {
return ProfitableToImplementDotInUntiledLlvmIr(*user) == return ProfitableToImplementDotInLlvmIr(*user) ==
DotInLlvmIrProfitable::kWithColumnMajorRhs && DotInLlvmIrProfitable::kWithColumnMajorRhs &&
user->operand(0) != rhs; user->operand(0) != rhs;
}); });

View File

@ -211,11 +211,9 @@ class DfsHloVisitorBase {
virtual Status HandlePad(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0;
virtual Status HandleSend(HloInstructionPtr send) = 0; virtual Status HandleSend(HloInstructionPtr hlo) = 0;
virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
virtual Status HandleRecv(HloInstructionPtr recv) = 0; virtual Status HandleRecv(HloInstructionPtr hlo) = 0;
virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;

View File

@ -167,17 +167,11 @@ class DfsHloVisitorWithDefaultBase
Status HandleWhile(HloInstructionPtr xla_while) override { Status HandleWhile(HloInstructionPtr xla_while) override {
return DefaultAction(xla_while); return DefaultAction(xla_while);
} }
Status HandleRecv(HloInstructionPtr recv) override {
return DefaultAction(recv);
}
Status HandleRecvDone(HloInstructionPtr recv_done) override {
return DefaultAction(recv_done);
}
Status HandleSend(HloInstructionPtr send) override { Status HandleSend(HloInstructionPtr send) override {
return DefaultAction(send); return DefaultAction(send);
} }
Status HandleSendDone(HloInstructionPtr send_done) override { Status HandleRecv(HloInstructionPtr recv) override {
return DefaultAction(send_done); return DefaultAction(recv);
} }
// Invoked to inform the visitor that the traversal has completed, and that // Invoked to inform the visitor that the traversal has completed, and that

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -280,13 +279,6 @@ std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
return algorithms; return algorithms;
} }
static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
}
return tensorflow::strings::StrCat(algo.algo_id());
}
tensorflow::Status ConvolutionThunk::ConvolveWithTune( tensorflow::Status ConvolutionThunk::ConvolveWithTune(
const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data, const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
const FilterDescriptor& filter_descriptor, const FilterDescriptor& filter_descriptor,
@ -311,8 +303,6 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
buffer_allocations.device_ordinal(), buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator()); buffer_allocations.memory_allocator());
se::dnn::ProfileResult profile_result; se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm)
<< " for ConvolutionThunk: " << this;
bool launch_ok = bool launch_ok =
Convolve(input_descriptor, input_data, filter_descriptor, filter_data, Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
output_descriptor, output_data, convolution_descriptor, output_descriptor, output_data, convolution_descriptor,
@ -320,11 +310,6 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
&scratch_allocator, &profile_result) &scratch_allocator, &profile_result)
.ok(); .ok();
if (launch_ok && profile_result.is_valid()) { if (launch_ok && profile_result.is_valid()) {
VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
<< " for ConvolutionThunk " << this << " succeeded, taking "
<< profile_result.elapsed_time_in_ms()
<< "ms. (Best result: " << best_result.elapsed_time_in_ms()
<< "ms)";
if (profile_result.elapsed_time_in_ms() < if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) { best_result.elapsed_time_in_ms()) {
best_result = profile_result; best_result = profile_result;
@ -334,9 +319,6 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
best_result_without_scratch.elapsed_time_in_ms()) { best_result_without_scratch.elapsed_time_in_ms()) {
best_result_without_scratch = profile_result; best_result_without_scratch = profile_result;
} }
} else {
VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
<< " for ConvolutionThunk " << this << " failed.";
} }
} }
@ -361,8 +343,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
{ {
VLOG(2) << "Using convolution algorithm (" VLOG(2) << "Using convolution algorithm ("
<< AlgorithmToString(best_algorithm_.algorithm()) << ", " << best_algorithm_.algorithm().algo_id() << ", "
<< AlgorithmToString(best_algorithm_.algorithm_no_scratch()) << best_algorithm_.algorithm_no_scratch().algo_id()
<< ") for ConvolutionThunk: " << this; << ") for ConvolutionThunk: " << this;
ConvolveScratchAllocator scratch_allocator( ConvolveScratchAllocator scratch_allocator(
buffer_allocations.device_ordinal(), buffer_allocations.device_ordinal(),

View File

@ -75,7 +75,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/tracing.h"
namespace se = ::perftools::gputools; namespace se = ::perftools::gputools;
@ -88,7 +87,6 @@ namespace gpu {
namespace { namespace {
using tensorflow::port::Tracing;
using tensorflow::strings::StrCat; using tensorflow::strings::StrCat;
// Any address of a variable residing in global memory or returned by one of the // Any address of a variable residing in global memory or returned by one of the
@ -233,7 +231,6 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// code (i.e. a cubin) as a byte array. // code (i.e. a cubin) as a byte array.
StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major, StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
int cc_minor) { int cc_minor) {
Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true);
const string ptxas_path = const string ptxas_path =
tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas");
VLOG(2) << "Using ptxas at " << ptxas_path; VLOG(2) << "Using ptxas at " << ptxas_path;
@ -298,15 +295,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) { std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr); TF_RET_CHECK(stream_exec != nullptr);
{ TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
Tracing::TraceMe annotation("HLO Transforms", module->name(), stream_exec->GetDeviceDescription(),
/*is_expensive=*/true); ShapeSizeBytesFunction()));
TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), TF_RETURN_IF_ERROR(
stream_exec->GetDeviceDescription(), PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
ShapeSizeBytesFunction()));
TF_RETURN_IF_ERROR(
PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
}
llvm::LLVMContext llvm_context; llvm::LLVMContext llvm_context;
std::string buffer; std::string buffer;
@ -451,7 +444,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
int cc_major, int cc_major,
int cc_minor) { int cc_minor) {
Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
bool inserted; bool inserted;
decltype(compilation_cache_.begin()) iter; decltype(compilation_cache_.begin()) iter;
// Pointers into compilation_cache_ where the ptx and (optional) cubin are // Pointers into compilation_cache_ where the ptx and (optional) cubin are

View File

@ -128,18 +128,10 @@ Status IrEmitter::HandleSend(HloInstruction*) {
return Unimplemented("Send is not implemented on GPU"); return Unimplemented("Send is not implemented on GPU");
} }
Status IrEmitter::HandleSendDone(HloInstruction*) {
return Unimplemented("Send-Done is not implemented on GPU");
}
Status IrEmitter::HandleRecv(HloInstruction*) { Status IrEmitter::HandleRecv(HloInstruction*) {
return Unimplemented("Recv is not implemented on GPU"); return Unimplemented("Recv is not implemented on GPU");
} }
Status IrEmitter::HandleRecvDone(HloInstruction*) {
return Unimplemented("Recv-done is not implemented on GPU");
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) { Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs; std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) { for (const HloInstruction* operand : tuple->operands()) {

View File

@ -84,9 +84,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleSort(HloInstruction* sort) override; Status HandleSort(HloInstruction* sort) override;
Status HandleSend(HloInstruction* send) override; Status HandleSend(HloInstruction* send) override;
Status HandleSendDone(HloInstruction* send_done) override;
Status HandleRecv(HloInstruction* recv) override; Status HandleRecv(HloInstruction* recv) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleParameter(HloInstruction* parameter) override; Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override; Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override; Status HandleTuple(HloInstruction* tuple) override;

View File

@ -60,7 +60,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
namespace xla { namespace xla {
namespace gpu { namespace gpu {
@ -489,9 +488,6 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
string ptx; string ptx;
{ {
tensorflow::port::Tracing::TraceMe annotation(
"Compiling IR", llvm_ir::AsString(module->getName()),
/*is_expensive=*/true);
ScopedLoggingTimer compilation_timer( ScopedLoggingTimer compilation_timer(
"Compile module " + llvm_ir::AsString(module->getName()), "Compile module " + llvm_ir::AsString(module->getName()),
/*vlog_level=*/2); /*vlog_level=*/2);

View File

@ -337,18 +337,10 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) {
return Status::OK(); return Status::OK();
} }
Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleRecv(const HloInstruction*) { Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
return Status::OK(); return Status::OK();
} }
Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleReshape(const HloInstruction*) { Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
return Status::OK(); return Status::OK();
} }

View File

@ -60,9 +60,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleReducePrecision(const HloInstruction* hlo) override; Status HandleReducePrecision(const HloInstruction* hlo) override;
Status HandleConcatenate(const HloInstruction* concatenate) override; Status HandleConcatenate(const HloInstruction* concatenate) override;
Status HandleSend(const HloInstruction* send) override; Status HandleSend(const HloInstruction* send) override;
Status HandleSendDone(const HloInstruction* send_done) override;
Status HandleRecv(const HloInstruction* recv) override; Status HandleRecv(const HloInstruction* recv) override;
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override; Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override; Status HandleCopy(const HloInstruction* copy) override;
Status HandleDot(const HloInstruction* dot) override; Status HandleDot(const HloInstruction* dot) override;

View File

@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// Test that two identical constants with different layouts are commoned if // Test that two identical constants with different layouts are commoned if
// the pass is not layout sensitive. // the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); /*minor_to_major=*/{0, 1})));
auto constant2 = builder.AddInstruction( auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); /*minor_to_major=*/{1, 0})));
auto add = builder.AddInstruction(HloInstruction::CreateBinary( auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2)); constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// Test that two identical constants with different layouts are *not* commoned // Test that two identical constants with different layouts are *not* commoned
// if the pass is layout sensitive. // if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); /*minor_to_major=*/{0, 1})));
auto constant2 = builder.AddInstruction( auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); /*minor_to_major=*/{1, 0})));
auto add = builder.AddInstruction(HloInstruction::CreateBinary( auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2)); constant1->shape(), HloOpcode::kAdd, constant1, constant2));

View File

@ -242,51 +242,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
return false; return false;
} }
bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
CHECK_EQ(send->opcode(), HloOpcode::kSend);
bool changed = false;
// Send forwards the operand value to the output tuple at {0}.
for (auto& pair : GetInstructionValueSet(send->operand(0))) {
const ShapeIndex& operand_index = pair.first;
const HloValueSet& operand_value_set = pair.second;
ShapeIndex index = {0};
for (int64 i : operand_index) {
index.push_back(i);
}
HloValueSet& value_set = GetValueSet(send, index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
bool changed = false;
// RecvDone forwards the operand value at {0} to the output.
for (auto& pair : GetInstructionValueSet(recv_done)) {
ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
ShapeIndex operand_index = {0};
for (int64 i : index) {
operand_index.push_back(i);
}
const HloValueSet& operand_value_set =
GetValueSet(recv_done->operand(0), operand_index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
CHECK_EQ(call->opcode(), HloOpcode::kCall); CHECK_EQ(call->opcode(), HloOpcode::kCall);
InstructionValueSet& value_set = GetInstructionValueSet(call); InstructionValueSet& value_set = GetInstructionValueSet(call);
@ -474,10 +429,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateCallValueSet(instruction); return UpdateCallValueSet(instruction);
case HloOpcode::kWhile: case HloOpcode::kWhile:
return UpdateWhileValueSet(instruction); return UpdateWhileValueSet(instruction);
case HloOpcode::kSend:
return UpdateSendValueSet(instruction);
case HloOpcode::kRecvDone:
return UpdateRecvDoneValueSet(instruction);
default: default:
// Instruction does not forward HloValues (it defines all values in its // Instruction does not forward HloValues (it defines all values in its
// output). No update is necessary. // output). No update is necessary.
@ -586,12 +537,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
GetValueSet(instruction, /*index=*/{}).AddValue(value); GetValueSet(instruction, /*index=*/{}).AddValue(value);
}; };
// Lambda to set the value set at the given index of the output.
auto define_value_at = [this, &instruction](const ShapeIndex& index) {
HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
GetValueSet(instruction, index).AddValue(value);
};
switch (instruction->opcode()) { switch (instruction->opcode()) {
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
if (bitcast_defines_value_) { if (bitcast_defines_value_) {
@ -632,16 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values flow from their operands. // values flow from their operands.
define_top_level_only(); define_top_level_only();
break; break;
case HloOpcode::kRecvDone:
// RecvDone aliases its input tuple element {0}, therefore does not
// define any values.
break;
case HloOpcode::kSend:
// Send produces a tuple of {aliased operand, U32 context}, therefore
// only defines the top-level tuple and the tuple element at {1}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
break;
default: default:
define_all_values(); define_all_values();
break; break;

View File

@ -146,9 +146,7 @@ class HloDataflowAnalysis {
bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateCopyValueSet(HloInstruction* copy);
bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
bool UpdateSelectValueSet(HloInstruction* select); bool UpdateSelectValueSet(HloInstruction* select);
bool UpdateSendValueSet(HloInstruction* send);
bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while); bool UpdateWhileValueSet(HloInstruction* xla_while);

View File

@ -1139,54 +1139,6 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
} }
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
// Test that a Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(param, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_EQ(analysis.values().size(), 4);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
// Test that a RecvDone forwards its operand tuple element at {0} to the
// output.
auto builder = HloComputation::Builder(TestName());
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_EQ(analysis.values().size(), 3);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
EXPECT_THAT(HloValuesAt(recv_done),
UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
EXPECT_TRUE(
analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
// A simple chain of elementwise operations. No values should interfere. // A simple chain of elementwise operations. No values should interfere.
// //

View File

@ -1450,10 +1450,6 @@ HloEvaluator::HloEvaluator() {
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this); typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this); typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this); typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
});
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
}); });

View File

@ -39,18 +39,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
HloEvaluator(); HloEvaluator();
// Evaluates an HLO module and an array of pointers to literals. // Evaluates an HLO module and an array of pointers to literals.
// Returns the evaluated result as a literal if successful. // Returns the evaluated result as a literal if successful.
// Precondition: The indices of arg_literals correspond to the parameter // Precondition: argument literals correspond to each input computation's
// numbers of the HLO parameters in the computation. See comment below for an // parameters in their post-ordering. See comment below for example.
// example.
StatusOr<std::unique_ptr<Literal>> Evaluate( StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloModule& module, const HloModule& module,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals); tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals. // Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful. // Returns the evaluated result as a literal if successful.
// Precondition: The indices of arg_literals correspond to the parameter // Precondition: argument literals correspond to the input computation's
// numbers of the HLO parameters in the computation. For e.g., consider the // parameters in their post-ordering. For e.g., consider the following graph:
// following graph:
// //
// * // *
// / \ // / \
@ -59,9 +57,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// / \ // / \
// Parameter0 Constant // Parameter0 Constant
// //
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // The input literals array will have its first literal map to Parameter0 and
// 1 in this computation. The input literals array will then have its first // the second map to Parameter1.
// literal map to Parameter0 and the second map to Parameter1.
StatusOr<std::unique_ptr<Literal>> Evaluate( StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloComputation& computation, const HloComputation& computation,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals); tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
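// Editorial usage sketch, not part of this change, illustrating the
// precondition above: arg_literals[i] must be the literal for the parameter
// whose parameter_number is i. The computation-building calls mirror ones
// used elsewhere in this commit; Literal::CreateR0 and `scalar_shape` are
// assumed here purely for illustration.
//
//   auto b = HloComputation::Builder("add");
//   auto x = b.AddInstruction(
//       HloInstruction::CreateParameter(0, scalar_shape, "x"));
//   auto y = b.AddInstruction(
//       HloInstruction::CreateParameter(1, scalar_shape, "y"));
//   b.AddInstruction(
//       HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, x, y));
//   auto computation = b.Build();
//
//   HloEvaluator evaluator;
//   auto arg0 = Literal::CreateR0<float>(1.0f);  // parameter_number 0
//   auto arg1 = Literal::CreateR0<float>(2.0f);  // parameter_number 1
//   auto result = evaluator.Evaluate(*computation, {arg0.get(), arg1.get()});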

View File

@ -761,22 +761,12 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands( string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) { const HloInstruction* instr) {
auto stringify_constant = [](const HloInstruction* constant) { auto stringify_constant = [](const HloInstruction* constant) {
const auto& shape = constant->shape(); if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
// Print the literal value of constants with <= K elements. constant->shape(), /*linear_index=*/0);
optional<int64> elem_count; return Printf("%s (%s)", constant->literal().GetAsString(elem_idx),
if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
elem_count = 1;
for (int64 dim : shape.dimensions()) {
*elem_count *= dim;
}
}
if (elem_count.has_value() && *elem_count <= 8) {
return Printf("%s (%s)", constant->literal().ToString(),
ShapeUtil::HumanString(constant->shape())); ShapeUtil::HumanString(constant->shape()));
} }
// Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name; string constant_name;
if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
constant_name = constant->name(); constant_name = constant->name();
@ -943,9 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kFusion: case HloOpcode::kFusion:
return kGray; return kGray;
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum: case HloOpcode::kCrossReplicaSum:
@ -1039,9 +1027,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
? "" ? ""
: StrCat("stride=", VectorString(instr->slice_strides())); : StrCat("stride=", VectorString(instr->slice_strides()));
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return StrCat("channel_id=", instr->channel_id()); return StrCat("channel_id=", instr->channel_id());
default: default:
return ""; return "";
@ -1303,9 +1289,7 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
auto is_displayed = [&](const HloInstruction* instr) { auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted. // Constants are displayed inline with their users; they're never omitted.
// Nodes in subcomputations are always shown. return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant;
return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
instr->parent() != root->parent();
}; };
// Make a second pass over 'nodes' to fix up the NodeFilterResults now that we // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we

View File

@ -371,50 +371,20 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, int64 channel_id) { HloInstruction* operand, int64 channel_id) {
// Send instruction produces a tuple of {aliased operand, U32 context}.
Shape output_shape = ShapeUtil::MakeTupleShape(
{operand->shape(), ShapeUtil::MakeShape(U32, {})});
auto instruction = auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
instruction->AppendOperand(operand); instruction->AppendOperand(operand);
instruction->channel_id_ = channel_id; instruction->channel_id_ = channel_id;
return instruction; return instruction;
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand) {
CHECK(operand->opcode() == HloOpcode::kSend)
<< "SendDone must take the context operand from Send";
auto instruction = WrapUnique(
new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
instruction->AppendOperand(operand);
instruction->channel_id_ = operand->channel_id();
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, int64 channel_id) { const Shape& shape, int64 channel_id) {
// Recv instruction produces a tuple of {receive buffer, U32 context}. auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
Shape output_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
instruction->channel_id_ = channel_id; instruction->channel_id_ = channel_id;
return instruction; return instruction;
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand) {
CHECK(operand->opcode() == HloOpcode::kRecv)
<< "RecvDone must take the context operand from Recv";
Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
instruction->AppendOperand(operand);
instruction->channel_id_ = operand->channel_id();
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand, const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) { tensorflow::gtl::ArraySlice<int64> dimensions) {
@ -938,9 +908,7 @@ RandomDistribution HloInstruction::random_distribution() const {
bool HloInstruction::HasSideEffect() const { bool HloInstruction::HasSideEffect() const {
switch (opcode_) { switch (opcode_) {
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kTrace: case HloOpcode::kTrace:
@ -1196,9 +1164,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[4], epsilon(), feature_index()); new_operands[4], epsilon(), feature_index());
break; break;
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace: case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
} }
@ -1591,10 +1557,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kSort: case HloOpcode::kSort:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone: case HloOpcode::kRecv:
return false; return false;
} }
} }
@ -1886,13 +1850,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
} }
if (window_ != nullptr) { if (window_ != nullptr) {
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); extra.push_back(window_util::ToString(*window_));
} }
if (padding_config_ != nullptr) { if (padding_config_ != nullptr) {
extra.push_back( extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
} }
if (opcode() == HloOpcode::kSlice) { if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds; std::vector<string> bounds;
bounds.reserve(slice_starts_.size()); bounds.reserve(slice_starts_.size());
const bool omit_stride = const bool omit_stride =
@ -1905,16 +1868,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
} }
extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
} }
if (opcode() == HloOpcode::kDynamicSlice) {
extra.push_back(
StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
}
if (opcode() == HloOpcode::kBatchNormTraining ||
opcode() == HloOpcode::kBatchNormInference ||
opcode() == HloOpcode::kBatchNormGrad) {
extra.push_back(StrCat("epsilon=", epsilon()));
extra.push_back(StrCat("feature_index=", feature_index()));
}
if (convolution_dimension_numbers_ != nullptr) { if (convolution_dimension_numbers_ != nullptr) {
extra.push_back(ConvolutionDimensionNumbersToString()); extra.push_back(ConvolutionDimensionNumbersToString());
@ -1938,8 +1891,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}))); })));
} }
if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
extra.push_back(StrCat("channel_id=", channel_id_)); extra.push_back(StrCat("channel_id=", channel_id_));
} }
@ -2119,10 +2071,8 @@ bool HloInstruction::IsFusable() const {
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kParameter: case HloOpcode::kParameter:
case HloOpcode::kTrace: case HloOpcode::kTrace:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone: case HloOpcode::kRecv:
return false; return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated // Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0) // will be different in each fusion. If it is the root (user count = 0)
@ -2329,14 +2279,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCall(this); return visitor->HandleCall(this);
case HloOpcode::kCustomCall: case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this); return visitor->HandleCustomCall(this);
case HloOpcode::kRecv:
return visitor->HandleRecv(this);
case HloOpcode::kRecvDone:
return visitor->HandleRecvDone(this);
case HloOpcode::kSend: case HloOpcode::kSend:
return visitor->HandleSend(this); return visitor->HandleSend(this);
case HloOpcode::kSendDone: case HloOpcode::kRecv:
return visitor->HandleSendDone(this); return visitor->HandleRecv(this);
// These opcodes are not handled here. // These opcodes are not handled here.
case HloOpcode::kTrace: case HloOpcode::kTrace:
@ -2895,21 +2841,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
} }
string PaddingConfigToString(const PaddingConfig& padding) {
bool has_interior_padding =
std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
return Join(
padding.dimensions(), "x",
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
StrAppend(
out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
});
}
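For reference, a short sketch of the strings the removed helper prints, derived directly from the Join above; the padding values are illustrative.
// Two dimensions, each with edge_padding_low=1 and edge_padding_high=2:
//   no interior padding on any dimension     -> "1_2x1_2"
//   interior_padding=3 on dimension 1 only   -> "1_2_0x1_2_3"
//     (once any dimension has interior padding, every dimension prints it)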
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind); return os << ToString(kind);
} }
@ -2925,7 +2856,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
const auto append_dims = [&](const std::vector<string>& dims, const auto append_dims = [&](const std::vector<string>& dims,
const Shape& shape) { const Shape& shape) {
CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
StrAppend(&result, Join(dims, "")); for (int64 logical = 0; logical < dims.size(); ++logical) {
int64 physical = logical;
if (!shape.layout().minor_to_major().empty()) {
physical = LayoutUtil::Major(shape.layout(), logical);
}
result += dims[physical];
}
}; };
// lhs_dims[i] is the symbol of the logical dimension i for the lhs // lhs_dims[i] is the symbol of the logical dimension i for the lhs


@ -181,28 +181,18 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, const Shape& shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config); tensorflow::StringPiece outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which // Creates a send instruction with the given channel id, which sends the
// initiates sending the operand data to a unique receive instruction in // operand data to a unique receive instruction in another computation that
// another computation that has the same channel id. // has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand, static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
int64 channel_id); int64 channel_id);
// Blocks until data transfer for the Send instruction (operand) is complete. // Creates a receive instruction with the given channel id, which receives
// The operand must be kSend. // data of the given shape from a unique send instruction in another
static std::unique_ptr<HloInstruction> CreateSendDone( // computation that has the same channel id.
HloInstruction* operand);
// Creates an asynchronous receive instruction with the given channel id,
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape, static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
int64 channel_id); int64 channel_id);
// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand);
// Creates a slice instruction, where the operand is sliced by the given // Creates a slice instruction, where the operand is sliced by the given
// start/limit indices. // start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice( static std::unique_ptr<HloInstruction> CreateSlice(
@ -212,7 +202,7 @@ class HloInstruction {
tensorflow::gtl::ArraySlice<int64> strides); tensorflow::gtl::ArraySlice<int64> strides);
// Creates a slice instruction, where the first operand is sliced by // Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specfied in // start indices specified in the second operand, and by size specified in
// 'slice_sizes'. // 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice( static std::unique_ptr<HloInstruction> CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, const Shape& shape, HloInstruction* operand,
@ -863,11 +853,6 @@ class HloInstruction {
return *window_; return *window_;
} }
// Sets the window data in a windowed operation such as convolution.
void set_window(const Window& window) {
window_ = MakeUnique<Window>(window);
}
// Returns the padding configuration for a pad node. // Returns the padding configuration for a pad node.
// //
// Precondition: opcode() == HloOpcode::kPad // Precondition: opcode() == HloOpcode::kPad
@ -1239,8 +1224,6 @@ string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind( StatusOr<HloInstruction::FusionKind> StringToFusionKind(
const string& kind_name); const string& kind_name);
string PaddingConfigToString(const PaddingConfig& padding);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// Map classes that guarantee a deterministic iteration order when the key is // Map classes that guarantee a deterministic iteration order when the key is


@ -792,8 +792,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
// sub = Sub(mul, clamp) // sub = Sub(mul, clamp)
// tuple = Tuple({sub, sub, mul, C1}) // tuple = Tuple({sub, sub, mul, C1})
// //
// Notable complexities are repeated operands in a same instruction, different // Notable complexities are repeated operands in the same instruction,
// shapes, use of value in different expressions. // different shapes, use of value in different expressions.
auto c1 = builder.AddInstruction( auto c1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
auto c2 = builder.AddInstruction( auto c2 = builder.AddInstruction(


@ -121,7 +121,6 @@ HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad); HLO_MATCHER(Pad);
HLO_MATCHER(Power); HLO_MATCHER(Power);
HLO_MATCHER(Recv); HLO_MATCHER(Recv);
HLO_MATCHER(RecvDone);
HLO_MATCHER(Reduce); HLO_MATCHER(Reduce);
HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReducePrecision);
HLO_MATCHER(ReduceWindow); HLO_MATCHER(ReduceWindow);
@ -132,7 +131,6 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select); HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter); HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send); HLO_MATCHER(Send);
HLO_MATCHER(SendDone);
HLO_MATCHER(ShiftLeft); HLO_MATCHER(ShiftLeft);
HLO_MATCHER(ShiftRightLogical); HLO_MATCHER(ShiftRightLogical);
HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(ShiftRightArithmetic);


@ -85,11 +85,7 @@ class HloModule {
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
// Return a pointer to the entry computation of the module. // Return a pointer to the entry computation of the module.
const HloComputation* entry_computation() const { HloComputation* entry_computation() const {
CHECK_NE(nullptr, entry_computation_);
return entry_computation_;
}
HloComputation* entry_computation() {
CHECK_NE(nullptr, entry_computation_); CHECK_NE(nullptr, entry_computation_);
return entry_computation_; return entry_computation_;
} }


@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout(
} }
string HloModuleConfig::compilation_cache_key() const { string HloModuleConfig::compilation_cache_key() const {
string key = string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_,
tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); "::hybrid=", has_hybrid_result_);
StrAppend(&key, "::("); StrAppend(&key, "::(");
std::vector<string> params; std::vector<string> params;
for (const ShapeLayout& param_layout : for (const ShapeLayout& param_layout :


@ -104,6 +104,16 @@ class HloModuleConfig {
// Whether to enable HLO-level profiling. // Whether to enable HLO-level profiling.
bool hlo_profiling_enabled_ = false; bool hlo_profiling_enabled_ = false;
// If this flag is true, the generated executable will return a ShapedBuffer
// holding the result of the computation. In a ShapedBuffer, tuples have their
// structure held in host memory and the element arrays (leaves of the tuple
// structure) stored in device memory. The ShapedBuffer is considered "hybrid"
// because its leaves are on device but its structure is stored on
// host. Otherwise, if this flag is false, the generated executable will
// return a DeviceMemoryBase where the result is held entirely in device
// memory.
bool has_hybrid_result_ = false;
// Module/graph-level seed handle. // Module/graph-level seed handle.
uint64 seed_ = 0; uint64 seed_ = 0;


@ -97,7 +97,6 @@ namespace xla {
V(kPower, "power") \ V(kPower, "power") \
V(kReal, "real") \ V(kReal, "real") \
V(kRecv, "recv") \ V(kRecv, "recv") \
V(kRecvDone, "recv-done") \
V(kReduce, "reduce") \ V(kReduce, "reduce") \
V(kReducePrecision, "reduce-precision") \ V(kReducePrecision, "reduce-precision") \
V(kReduceWindow, "reduce-window") \ V(kReduceWindow, "reduce-window") \
@ -109,7 +108,6 @@ namespace xla {
V(kSelect, "select") \ V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \ V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \ V(kSend, "send") \
V(kSendDone, "send-done") \
V(kShiftLeft, "shift-left") \ V(kShiftLeft, "shift-left") \
V(kShiftRightArithmetic, "shift-right-arithmetic") \ V(kShiftRightArithmetic, "shift-right-arithmetic") \
V(kShiftRightLogical, "shift-right-logical") \ V(kShiftRightLogical, "shift-right-logical") \


@ -66,9 +66,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
case HloOpcode::kInfeed: case HloOpcode::kInfeed:
case HloOpcode::kParameter: case HloOpcode::kParameter:
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace: case HloOpcode::kTrace:
case HloOpcode::kWhile: case HloOpcode::kWhile:
return false; return false;


@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/service/hlo_runner.h"
@ -20,6 +19,8 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/ptr_util.h"


@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace xla { namespace xla {
@ -39,15 +38,6 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
} }
string HloSharding::ToString() const { string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;
parts.reserve(tuple_elements_.size());
for (const HloSharding& element : tuple_elements_) {
parts.push_back(element.ToString());
}
return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
}
string result = StrCat("{", (replicated_ ? " replicated" : ""), string result = StrCat("{", (replicated_ ? " replicated" : ""),
(maximal_ ? " maximal" : "")); (maximal_ ? " maximal" : ""));
@ -63,11 +53,6 @@ string HloSharding::ToString() const {
} }
bool HloSharding::UsesDevice(int64 device) const { bool HloSharding::UsesDevice(int64 device) const {
if (IsTuple()) {
return std::any_of(
tuple_elements_.begin(), tuple_elements_.end(),
[&](const HloSharding& s) { return s.UsesDevice(device); });
}
const auto& devices = tile_assignment_; const auto& devices = tile_assignment_;
return replicated_ || return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end(); std::find(devices.begin(), devices.end(), device) != devices.end();
@ -76,7 +61,6 @@ bool HloSharding::UsesDevice(int64 device) const {
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const { std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_); CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index; std::vector<int64> ret_index;
tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) { tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
if (d == device) { if (d == device) {
@ -90,7 +74,6 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
int64 HloSharding::DeviceForTileIndex( int64 HloSharding::DeviceForTileIndex(
tensorflow::gtl::ArraySlice<int64> index) const { tensorflow::gtl::ArraySlice<int64> index) const {
CHECK(!replicated_); CHECK(!replicated_);
CHECK(!IsTuple());
if (maximal_) { if (maximal_) {
return *tile_assignment_.begin(); return *tile_assignment_.begin();
} }
@ -99,7 +82,7 @@ int64 HloSharding::DeviceForTileIndex(
} }
std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const { std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
CHECK(!IsTuple()); CHECK(!ShapeUtil::IsTuple(tile_shape_));
std::vector<int64> index = TileIndexForDevice(device); std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) { if (maximal_) {
@ -114,7 +97,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
} }
std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const { std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
CHECK(!IsTuple()); CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
std::vector<int64> index = TileIndexForDevice(device); std::vector<int64> index = TileIndexForDevice(device);
@ -125,41 +108,13 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
} }
StatusOr<int64> HloSharding::UniqueDevice() const { StatusOr<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) { if (!replicated_ && maximal_) {
if (tuple_elements_.empty()) {
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on empty tuple");
}
std::vector<StatusOr<int64>> results;
std::transform(tuple_elements_.begin(), tuple_elements_.end(),
std::back_inserter(results),
[](const HloSharding& s) { return s.UniqueDevice(); });
if (std::all_of(results.begin(), results.end(),
[&](const StatusOr<int64>& s) {
return s.ok() && results[0].ok() &&
s.ValueOrDie() == results[0].ValueOrDie();
})) {
return results[0];
} else {
return tensorflow::errors::InvalidArgument(
"Tuple did not contain a unique device");
}
}
if (!replicated_ && maximal_ && !IsTuple()) {
return static_cast<int64>(*tile_assignment_.begin()); return static_cast<int64>(*tile_assignment_.begin());
} }
return tensorflow::errors::InvalidArgument( return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices"); "UniqueDevice() called on sharding that executes on multiple devices");
} }
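A brief sketch of what the removed tuple handling above yields, assuming single-device (maximal) element shardings built with HloSharding::AssignDevice:
// {AssignDevice(3), AssignDevice(3)} -> UniqueDevice() returns 3
// {AssignDevice(0), AssignDevice(1)} -> InvalidArgument: no unique device
// {}  (empty tuple sharding)         -> InvalidArgument: empty tuple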
bool HloSharding::HasUniqueDevice() const {
if (IsTuple()) {
return UniqueDevice().status().ok();
} else {
return !IsReplicated() && IsTileMaximal();
}
}
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
if (replicated_) { if (replicated_) {
return Status::OK(); return Status::OK();
@ -238,19 +193,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
/*static*/ StatusOr<HloSharding> HloSharding::FromProto( /*static*/ StatusOr<HloSharding> HloSharding::FromProto(
const OpSharding& proto) { const OpSharding& proto) {
if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
std::vector<HloSharding> tuple_shardings;
tuple_shardings.reserve(proto.tuple_shardings().size());
for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
TF_ASSIGN_OR_RETURN(HloSharding sharding,
HloSharding::FromProto(tuple_sharding_proto));
tuple_shardings.push_back(sharding);
}
return HloSharding(tuple_shardings);
} else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate(); return Replicate();
} else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0)); return HloSharding(proto.tile_assignment_devices(0));
} }
// Some versions of gcc cannot infer the TileAssignment constructor from a // Some versions of gcc cannot infer the TileAssignment constructor from a
@ -267,15 +212,6 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
OpSharding HloSharding::ToProto() const { OpSharding HloSharding::ToProto() const {
OpSharding result; OpSharding result;
if (IsTuple()) {
for (const HloSharding& element : tuple_elements_) {
*result.add_tuple_shardings() = element.ToProto();
}
result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
return result;
}
*result.mutable_tile_shape() = tile_shape_; *result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) { for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim); result.add_tile_assignment_dimensions(dim);


@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
@ -68,18 +67,6 @@ class HloSharding {
// `num_tiles` tiles. // `num_tiles` tiles.
static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
// Creates a new sharding for a tuple type. The given ShapeTree must have
// elements for every leaf shape contained in the tuple.
static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
std::vector<HloSharding> flattened_list;
flattened_list.reserve(
std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
for (const auto& index_to_sharding : sub_shardings.leaves()) {
flattened_list.push_back(index_to_sharding.second);
}
return HloSharding(flattened_list);
}
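A minimal sketch of the removed tuple-sharding factory in use, assuming the XLA headers (shape_tree.h, shape_util.h); the tuple shape and device assignments are illustrative only.
Shape tuple_shape = ShapeUtil::MakeTupleShape(
    {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});
// Default every leaf to replicated, then pin element {1} to device 0.
ShapeTree<HloSharding> shardings(tuple_shape, HloSharding::Replicate());
*shardings.mutable_element({1}) = HloSharding::AssignDevice(0);
HloSharding sharding = HloSharding::Tuple(shardings);
CHECK(sharding.IsTuple());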
// Create a new sharding from a protobuf OpSharding. // Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto); static StatusOr<HloSharding> FromProto(const OpSharding& proto);
@ -89,89 +76,47 @@ class HloSharding {
// Validate that this sharding can be applied to a tensor with shape `shape`. // Validate that this sharding can be applied to a tensor with shape `shape`.
Status Validate(const Shape& shape, int64 num_devices) const; Status Validate(const Shape& shape, int64 num_devices) const;
// Returns true if the sharding has tuple type.
bool IsTuple() const { return tuple_; }
// Returns true if the sharding is trivial: replicate on all devices. // Returns true if the sharding is trivial: replicate on all devices.
bool IsReplicated() const { bool IsReplicated() const { return replicated_; }
if (!IsTuple()) {
return replicated_;
}
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
[](const HloSharding& s) { return s.IsReplicated(); });
}
// Returns true if the tile size is the same as the input size. // Returns true if the tile size is the same as the input size.
bool IsTileMaximal() const { bool IsTileMaximal() const { return maximal_; }
if (!IsTuple()) {
return maximal_;
}
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
[](const HloSharding& s) { return s.IsTileMaximal(); });
}
// Returns true if the sharding defines an operation on the given device. // Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const; bool UsesDevice(int64 device) const;
// Returns the tile that should be executed on the given device. // Returns the tile that should be executed on the given device.
// REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const; std::vector<int64> TileIndexForDevice(int64 device) const;
// Returns the device that should execute the given tile. // Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true. // It is an error to call this if is_replicated() is true.
// REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const; int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
// Given a device ID, returns the offset within the input space of the // Given a device ID, returns the offset within the input space of the
// tile that should be executed on the given core. This returns the lower // tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space. // extent of the tile in the input space.
// REQUIRES: !IsTuple()
std::vector<int64> TileOffsetForDevice(int64 device) const; std::vector<int64> TileOffsetForDevice(int64 device) const;
// Given a device ID, returns the limit within the input space of the // Given a device ID, returns the limit within the input space of the
// tile that should be executed on the given core. This returns the upper // tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space. // extent of the tile in the input space.
// REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const; std::vector<int64> TileLimitForDevice(int64 device) const;
// Returns the single device this op operates on. // Returns the single device this op operates on.
// REQUIRES: !IsTuple() && !Replicated() && IsTileMaximal() // Requires !Replicated() && IsTileMaximal().
StatusOr<int64> UniqueDevice() const; StatusOr<int64> UniqueDevice() const;
// Returns true if this op only uses a single device. // Returns true if this op only uses a single device.
bool HasUniqueDevice() const; bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
// Returns the ShapeTree containing the shardings for each element of this
// tuple. Only the leaf elements are populated. This creates a new ShapeTree
// object so is not cheap. REQUIRES: IsTuple()
ShapeTree<HloSharding> GetTupleShardingsAsShapeTree(
const Shape& tuple_shape) const {
ShapeTree<HloSharding> result(tuple_shape, HloSharding::Replicate());
CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
tuple_elements_.size());
auto it = tuple_elements_.begin();
for (auto& index_to_sharding : result.leaves()) {
index_to_sharding.second = *it++;
}
return result;
}
bool operator==(const HloSharding& other) const { bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ && return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
tile_assignment_ == other.tile_assignment_ && tile_assignment_ == other.tile_assignment_;
tuple_elements_ == other.tuple_elements_;
} }
bool operator!=(const HloSharding& other) const { return !(*this == other); } bool operator!=(const HloSharding& other) const { return !(*this == other); }
size_t Hash() const { size_t Hash() const {
if (!tuple_) {
size_t h = 0;
for (const auto& element : tuple_elements_) {
h = tensorflow::Hash64Combine(h, element.Hash());
}
return h;
}
if (replicated_) { if (replicated_) {
return 0; return 0;
} }
@ -186,47 +131,33 @@ class HloSharding {
} }
// Gets the tile shape. // Gets the tile shape.
// REQUIRES: !IsTileMaximal() && !IsTuple() // It is an error to call this if IsTileMaximal() is true.
const Shape& tile_shape() const { return tile_shape_; } const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor. // Gets the tile assignment tensor.
// REQUIRES: !IsReplicated() && !IsTuple() // It is an error to call this if IsReplicated() is true.
const Array<int64>& tile_assignment() const { return tile_assignment_; } const Array<int64>& tile_assignment() const { return tile_assignment_; }
private: private:
HloSharding() HloSharding()
: replicated_(true), : replicated_(true),
maximal_(true), maximal_(true),
tuple_(false),
tile_shape_(), tile_shape_(),
tile_assignment_({0}) {} tile_assignment_({0}) {}
explicit HloSharding(int64 device_id) explicit HloSharding(int64 device_id)
: replicated_(false), : replicated_(false),
maximal_(true), maximal_(true),
tuple_(false),
tile_shape_(), tile_shape_(),
tile_assignment_({1}, device_id) {} tile_assignment_({1}, device_id) {}
HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment) HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
: replicated_(false), : replicated_(false),
maximal_(false), maximal_(false),
tuple_(false),
tile_shape_(tile_shape), tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {} tile_assignment_(tile_assignment) {}
HloSharding(const std::vector<HloSharding>& tuple_shardings)
: replicated_(false),
maximal_(false),
tuple_(true),
tile_assignment_({0}),
tuple_elements_(tuple_shardings) {}
bool replicated_; bool replicated_;
bool maximal_; bool maximal_;
bool tuple_;
Shape tile_shape_; Shape tile_shape_;
Array<int64> tile_assignment_; Array<int64> tile_assignment_;
// Only non-empty when tuple_ is true, but because empty tuples are allowed
// may also be empty even then. This is a flattened list of all the leaf
// shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
std::vector<HloSharding> tuple_elements_;
}; };
} // namespace xla } // namespace xla


@ -132,29 +132,6 @@ TEST_F(HloShardingTest, Tile) {
} }
} }
TEST_F(HloShardingTest, NestedTuple) {
// nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {}),
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
ShapeUtil::MakeShape(F32, {4, 6}),
});
OpSharding proto;
proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
*proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
*proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
*proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
HloSharding tuple_sharding =
HloSharding::FromProto(proto).ConsumeValueOrDie();
ShapeTree<HloSharding> shape_tree =
tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape);
EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
}
TEST_F(HloShardingTest, Hash) { TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) { if (a.Hash() != b.Hash()) {
@ -207,51 +184,6 @@ TEST_F(HloShardingTest, Hash) {
MakeArray({2, 2}, {0, 3, 1, 2})); MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
} }
HloSharding default_sharding = HloSharding::Replicate();
{
ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
default_sharding);
HloSharding sharding1 = HloSharding::Replicate();
HloSharding sharding2 = HloSharding::Tuple(shape_tree);
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
{
ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
default_sharding);
HloSharding sharding1 = HloSharding::Tuple(shape_tree);
HloSharding sharding2 = HloSharding::Tuple(shape_tree);
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
{
ShapeTree<HloSharding> shape_tree1(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
default_sharding);
*shape_tree1.mutable_element({0}) = HloSharding::Replicate();
ShapeTree<HloSharding> shape_tree2(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
default_sharding);
*shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
{
ShapeTree<HloSharding> shape_tree1(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
default_sharding);
*shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
ShapeTree<HloSharding> shape_tree2(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
default_sharding);
*shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
} }
} // namespace } // namespace


@ -270,40 +270,12 @@ class ShapeVerifier : public DfsHloVisitor {
pad->padding_config())); pad->padding_config()));
} }
Status HandleSend(HloInstruction* send) override { Status HandleSend(HloInstruction*) override {
TF_RET_CHECK(send->users().size() == 1); return tensorflow::Status::OK();
const HloInstruction* send_done = send->users()[0];
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
return CheckShape(
send, ShapeUtil::MakeTupleShape(
{send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
} }
Status HandleSendDone(HloInstruction* send_done) override { Status HandleRecv(HloInstruction*) override {
TF_RET_CHECK(send_done->operands().size() == 1); return tensorflow::Status::OK();
const HloInstruction* send = send_done->operand(0);
TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
return CheckShape(send_done, ShapeUtil::MakeNil());
}
Status HandleRecv(HloInstruction* recv) override {
TF_RET_CHECK(recv->users().size() == 1);
const HloInstruction* recv_done = recv->users()[0];
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
return CheckShape(recv,
ShapeUtil::MakeTupleShape(
{recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
}
Status HandleRecvDone(HloInstruction* recv_done) override {
TF_RET_CHECK(recv_done->operands().size() == 1);
const HloInstruction* recv = recv_done->operand(0);
TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
return CheckShape(recv_done, recv->shape().tuple_shapes(0));
} }
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
@ -393,19 +365,6 @@ class ShapeVerifier : public DfsHloVisitor {
instruction->opcode(), instruction->operands())); instruction->opcode(), instruction->operands()));
} }
// Checks if the given two instructions shares the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
if (instr1->channel_id() != instr2->channel_id()) {
return FailedPrecondition(
"Expected to have the same channel id, actual channel ids are: %s "
"(%lld), %s (%lld)",
instr1->ToString().c_str(), instr1->channel_id(),
instr2->ToString().c_str(), instr2->channel_id());
}
return tensorflow::Status::OK();
}
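Taken together, a summary sketch of the invariants the removed handlers and helper above enforce:
// send      : shape (operand_shape, u32[]); its single user is a send-done on the same channel.
// send-done : shape ();                     its single operand is that send.
// recv      : shape (element_shape, u32[]); its single user is a recv-done on the same channel.
// recv-done : shape recv.tuple_shapes(0);   its single operand is that recv.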
// Returns the size of a Shape in bytes. // Returns the size of a Shape in bytes.
const std::function<int64(const Shape&)> shape_size_fn_; const std::function<int64(const Shape&)> shape_size_fn_;
}; };


@ -113,9 +113,7 @@ namespace xla {
case HloOpcode::kTrace: case HloOpcode::kTrace:
case HloOpcode::kWhile: case HloOpcode::kWhile:
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv: case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
return true; return true;
} }


@ -89,7 +89,7 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
uint64 start_micros = tensorflow::Env::Default()->NowMicros(); uint64 start_micros = tensorflow::Env::Default()->NowMicros();
const HloComputation* computation = module().entry_computation(); HloComputation* computation = module().entry_computation();
if (computation->num_parameters() != arguments.size()) { if (computation->num_parameters() != arguments.size()) {
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
"Mismatch between argument count and graph parameter count."); "Mismatch between argument count and graph parameter count.");


@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}}; std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) { for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant_literal1 = Literal::CreateR2WithLayout<float>( auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
auto constant_literal2 = Literal::CreateR2WithLayout<float>( auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
Shape ashape = constant_literal1->shape(); Shape ashape = constant_literal1->shape();
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(
@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// Verify the layouts of a tuple are assigned properly (the element layouts // Verify the layouts of a tuple are assigned properly (the element layouts
// match their source). // match their source).
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction( auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); {0, 1})));
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); {1, 0})));
auto tuple = builder.AddInstruction( auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1})); HloInstruction::CreateTuple({constant0, constant1}));
@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
TEST_F(LayoutAssignmentTest, TupleSelect) { TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly. // Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction( auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); {0, 1})));
auto constant1 = builder.AddInstruction( auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); {1, 0})));
auto tuple0 = builder.AddInstruction( auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1})); HloInstruction::CreateTuple({constant0, constant1}));
auto tuple1 = builder.AddInstruction( auto tuple1 = builder.AddInstruction(


@ -155,30 +155,6 @@ cc_library(
], ],
) )
cc_library(
name = "vector_support_library",
srcs = ["vector_support_library.cc"],
hdrs = ["vector_support_library.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"@llvm//:core",
],
)
cc_library(
name = "kernel_support_library",
srcs = ["kernel_support_library.cc"],
hdrs = ["kernel_support_library.h"],
deps = [
":llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@llvm//:core",
],
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
filegroup( filegroup(


@ -1,65 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
namespace xla {
void KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value*, bool)>& for_body_generator) {
If(ir_builder_->CreateICmpSLT(start, end), [&]() {
for_body_generator(start, /*is_first_iteration=*/true);
For(name, ir_builder_->CreateAdd(start, step), end, step,
[&](llvm::Value* iv) { for_body_generator(iv, false); });
});
}
void KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<void(llvm::Value*, llvm::Value*)>& for_body_generator) {
if (peel_first_iteration) {
For(name, start, end, step, true,
[&](llvm::Value* indvar, bool is_first_iteration) {
for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration));
});
} else {
std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
name, start, end, step, ir_builder_,
/*prevent_unrolling=*/prevent_unrolling_,
/*prevent_vectorization=*/prevent_vectorization_);
ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
for_body_generator(loop->GetIndVarValue(),
/*is_first_iteration=*/ir_builder_->CreateICmpEQ(
loop->GetIndVarValue(), start));
llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_);
}
}
void KernelSupportLibrary::If(
llvm::Value* condition, const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator) {
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
ir_builder_->SetInsertPoint(&if_data.true_block->back());
true_block_generator();
ir_builder_->SetInsertPoint(&if_data.false_block->back());
false_block_generator();
llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
}
} // namespace xla


@ -1,128 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
#include <string>
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
// flow more readable.
class KernelSupportLibrary {
public:
// `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
// If `prevent_unrolling` is true then unrolling is explicitly disabled on
// every loop generated by this instance of KernelSupportLibrary.
explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
bool prevent_unrolling = true,
bool prevent_vectorization = true)
: ir_builder_(ir_builder),
prevent_unrolling_(prevent_unrolling),
prevent_vectorization_(prevent_vectorization) {}
// Generates the following control flow structure:
//
// if (`start` < `end`) {
// `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
// for (i64 i = `start` + `step`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
// }
void For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator);
void For(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
For(name, /*start=*/ir_builder_->getInt64(start),
/*end=*/ir_builder_->getInt64(end),
/*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure if `peel_first_iteration` is
// true:
//
// if (`start` < `end`) {
// `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
// for (i64 i = `start` + `step`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
// }
//
// and the following if `peel_first_iteration` is false:
//
// for (i64 i = `start`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/i,
// /*is_first_iteration=*/(i != `start`))`;
void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
for_body_generator);
void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
int64 step, bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
for_body_generator) {
For(name, /*start=*/start, /*end=*/end,
/*step=*/ir_builder_->getInt64(step), peel_first_iteration,
for_body_generator);
}
void For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
For(name, start, end, step,
/*peel_first_iteration=*/false,
[&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
}
void For(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
For(name, /*start=*/ir_builder_->getInt64(start),
/*end=*/ir_builder_->getInt64(end),
/*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure:
//
// if (`condition`)
// `true_block_generator()`;
// else
// `false_block_generator()`;
void If(llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {});
private:
llvm::IRBuilder<>* ir_builder_;
bool prevent_unrolling_;
bool prevent_vectorization_;
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
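A minimal usage sketch of the deleted helper, assuming code inside namespace xla with an llvm::IRBuilder<> `b` positioned in a function under construction and an i1 llvm::Value* `pred`; both `b` and `pred`, and the loop bounds, are assumptions for illustration.
KernelSupportLibrary ksl(&b, /*prevent_unrolling=*/true,
                         /*prevent_vectorization=*/true);
// Emits: for (i64 i = 0; i s< 16; i += 1) { ...body... }
ksl.For("init_loop", /*start=*/0, /*end=*/16, /*step=*/1,
        [&](llvm::Value* i) {
          // Emit the loop body for induction variable `i` here.
        });
// Emits: if (pred) { ...then... } else { ...else... }
ksl.If(pred, [&] { /* then-block IR */ }, [&] { /* else-block IR */ });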


@ -34,24 +34,21 @@ namespace llvm_ir {
ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* start_index, llvm::Value* end_index,
llvm::Value* step, bool prevent_unrolling, llvm::Value* step, bool prevent_unrolling)
bool prevent_vectorization)
: prefix_(prefix.ToString()), : prefix_(prefix.ToString()),
suffix_(suffix.ToString()), suffix_(suffix.ToString()),
start_index_(start_index), start_index_(start_index),
end_index_(end_index), end_index_(end_index),
step_(step), step_(step),
insert_before_bb_(nullptr), insert_before_bb_(nullptr),
prevent_unrolling_(prevent_unrolling), prevent_unrolling_(prevent_unrolling) {}
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop( /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index, tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
bool prevent_unrolling, bool prevent_vectorization) { bool prevent_unrolling) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index, std::unique_ptr<ForLoop> loop(new ForLoop(
end_index, step, prevent_unrolling, prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling));
prevent_vectorization));
loop->Emit(ir_builder); loop->Emit(ir_builder);
return loop; return loop;
} }
@ -130,12 +127,14 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->CreateStore(indvar_inc, indvar_address); ir_builder->CreateStore(indvar_inc, indvar_address);
llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder); if (prevent_unrolling_) {
if (!loop_metadata.empty()) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
llvm::LLVMContext* ctx = &start_index_->getContext(); llvm::LLVMContext* ctx = &back_branch->getContext();
auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
loop_metadata.insert(loop_metadata.begin(), temp_node.get()); auto no_unroll_node = llvm::MDNode::get(
auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)});
auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node});
loop_id->replaceOperandWith(0, loop_id); loop_id->replaceOperandWith(0, loop_id);
back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
} }
@ -144,27 +143,6 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->SetInsertPoint(exit_bb_); ir_builder->SetInsertPoint(exit_bb_);
} }
std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
llvm::IRBuilder<>* ir_builder) {
const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
llvm::LLVMContext* ctx = &start_index_->getContext();
std::vector<llvm::Metadata*> result;
if (prevent_unrolling_) {
result.push_back(llvm::MDNode::get(
*ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
}
if (prevent_vectorization_) {
result.push_back(llvm::MDNode::get(
*ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
}
return result;
}
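A sketch of how a caller opts into this metadata through the removed-side EmitForLoop overload (declared in the header hunk below), assuming code in namespace xla::llvm_ir with an llvm::IRBuilder<> `b` and an i64 llvm::Value* `n` as the trip count; both are assumptions for illustration.
std::unique_ptr<ForLoop> loop = ForLoop::EmitForLoop(
    "reduce", /*start_index=*/b.getInt64(0), /*end_index=*/n,
    /*step=*/b.getInt64(1), &b,
    /*prevent_unrolling=*/true, /*prevent_vectorization=*/true);
b.SetInsertPoint(&loop->GetBodyBasicBlock()->back());
// Emit the body at loop->GetIndVarValue(); the loop's back-branch then carries
// both llvm.loop.unroll.disable and llvm.loop.vectorize.enable = false.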
string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
} }
@ -178,25 +156,23 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* end_index,
bool prevent_unrolling, bool prevent_unrolling) {
bool prevent_vectorization) {
return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
prevent_unrolling, prevent_vectorization); prevent_unrolling);
} }
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* end_index,
llvm::Value* stride, llvm::Value* stride,
bool prevent_unrolling, bool prevent_unrolling) {
bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) { if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one. // Create this loop inside the previous one.
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
} }
std::unique_ptr<ForLoop> loop(new ForLoop( std::unique_ptr<ForLoop> loop(new ForLoop(
/*prefix=*/name_, suffix, start_index, end_index, stride, /*prefix=*/name_, suffix, start_index, end_index, stride,
prevent_unrolling, prevent_vectorization)); prevent_unrolling));
loop->Emit(ir_builder_); loop->Emit(ir_builder_);
if (outer_loop_preheader_bb_ == nullptr) { if (outer_loop_preheader_bb_ == nullptr) {
@ -215,24 +191,20 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 end_index,
tensorflow::StringPiece suffix, tensorflow::StringPiece suffix,
bool prevent_unrolling, bool prevent_unrolling) {
bool prevent_vectorization) {
CHECK_LE(start_index, end_index); CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index), return AddLoop(suffix, ir_builder_->getInt64(start_index),
ir_builder_->getInt64(end_index), prevent_unrolling, ir_builder_->getInt64(end_index), prevent_unrolling);
prevent_vectorization);
} }
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride, int64 end_index, int64 stride,
tensorflow::StringPiece suffix, tensorflow::StringPiece suffix,
bool prevent_unrolling, bool prevent_unrolling) {
bool prevent_vectorization) {
CHECK_LE(start_index, end_index); CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index), return AddLoop(suffix, ir_builder_->getInt64(start_index),
ir_builder_->getInt64(end_index), ir_builder_->getInt64(end_index),
ir_builder_->getInt64(stride), prevent_unrolling, ir_builder_->getInt64(stride), prevent_unrolling);
prevent_vectorization);
} }
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
@@ -71,10 +71,12 @@ class ForLoop {
  //
  // If `prevent_unrolling` is true then emit metadata that directs LLVM to not
  // unroll the generated loop.
-  static std::unique_ptr<ForLoop> EmitForLoop(
-      tensorflow::StringPiece prefix, llvm::Value* start_index,
-      llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
-      bool prevent_unrolling = false, bool prevent_vectorization = false);
+  static std::unique_ptr<ForLoop> EmitForLoop(tensorflow::StringPiece prefix,
+                                              llvm::Value* start_index,
+                                              llvm::Value* end_index,
+                                              llvm::Value* step,
+                                              llvm::IRBuilder<>* ir_builder,
+                                              bool prevent_unrolling = false);

  // The names of the blocks follow LLVM's conventions. Control flow amongst the
  // blocks for the example C code looks like:
@@ -128,7 +130,7 @@ class ForLoop {
  ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
          llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
-         bool prevent_unrolling, bool prevent_vectorization);
+         bool prevent_unrolling);

  // Emit the loop at the insert point of the builder.
  void Emit(llvm::IRBuilder<>* ir_builder);
@@ -140,10 +142,6 @@ class ForLoop {
  // they are set.
  string GetQualifiedName(tensorflow::StringPiece name);

-  // Return a list of metadata nodes that should be associated with the
-  // llvm::Loop for this `ForLoop`.
-  std::vector<llvm::Metadata*> GetLoopMetadata(llvm::IRBuilder<>* ir_builder);
-
  string prefix_;
  string suffix_;
  llvm::Value* start_index_;
@@ -162,7 +160,6 @@ class ForLoop {
  llvm::BasicBlock* exit_bb_;
  llvm::Value* indvar_;
  bool prevent_unrolling_;
-  bool prevent_vectorization_;

  TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
};
@@ -188,28 +185,24 @@ class ForLoopNest {
  std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
                                   llvm::Value* start_index,
                                   llvm::Value* end_index, llvm::Value* stride,
-                                  bool prevent_unrolling = false,
-                                  bool prevent_vectorization = false);
+                                  bool prevent_unrolling = false);

  // Like the above, except that it defaults to a stride of one.
  std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
                                   llvm::Value* start_index,
                                   llvm::Value* end_index,
-                                  bool prevent_unrolling = false,
-                                  bool prevent_vectorization = false);
+                                  bool prevent_unrolling = false);

  // A convenient wrapper of the other flavor of AddLoop. The given start and
  // end index are constant.
  std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
                                   int64 stride, tensorflow::StringPiece suffix,
-                                  bool prevent_unrolling = false,
-                                  bool prevent_vectorization = false);
+                                  bool prevent_unrolling = false);

  // Like the above, except that it defaults to a stride of one.
  std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
                                   tensorflow::StringPiece suffix,
-                                  bool prevent_unrolling = false,
-                                  bool prevent_vectorization = false);
+                                  bool prevent_unrolling = false);

  // Add loops to iterate through the indices within the specified
  // shape. The returned index collects the induction variables of the
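As a point of reference, the loop nest built for a rank-2 shape corresponds roughly to the following plain C++; the concrete dimensions and the copy body are illustrative assumptions, not taken from the XLA sources.

#include <cstdint>

// Hypothetical equivalent of a ForLoopNest over a shape [8, 16]: one loop per
// dimension, with the collected index {i0, i1} addressing the flat buffer.
void EquivalentLoopNest(const float* operand, float* result) {
  for (int64_t i0 = 0; i0 < 8; ++i0) {     // loop added for dimension 0
    for (int64_t i1 = 0; i1 < 16; ++i1) {  // loop added for dimension 1
      result[i0 * 16 + i1] = operand[i0 * 16 + i1];
    }
  }
}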
@@ -537,14 +537,6 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
  builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
}

-void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
-  if (llvm::Instruction* terminator = blk->getTerminator()) {
-    builder->SetInsertPoint(terminator);
-  } else {
-    builder->SetInsertPoint(blk);
-  }
-}
-
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder) {
  auto size = rotand->getType()->getPrimitiveSizeInBits();
@@ -243,8 +243,6 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);

-void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
-
// Create a bitwise rotation of `rotand` by `rotor`.
llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                       llvm::IRBuilder<>* builder);
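CreateRor builds the rotation out of two shifts and an or. A scalar C++ model of the same idea for 32-bit values; the masking of the rotation amount is an assumption of this sketch, not a claim about the IR helper.

#include <cstdint>

// Rotate `rotand` right by `rotor` bits, for a 32-bit operand.
uint32_t RotateRight32(uint32_t rotand, uint32_t rotor) {
  rotor &= 31;  // keep the shift amount in [0, 31] so neither shift is undefined
  return (rotand >> rotor) | (rotand << ((32 - rotor) & 31));
}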
@@ -1,150 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
namespace xla {
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
int64 vector_size,
llvm::IRBuilder<>* ir_builder,
std::string name)
: vector_size_(vector_size),
primitive_type_(primitive_type),
ir_builder_(ir_builder),
name_(std::move(name)) {
scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
primitive_type, ir_builder_->GetInsertBlock()->getModule());
scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}
llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
if (scalar_type_->isFloatingPointTy()) {
return ir_builder()->CreateFMul(lhs, rhs, name());
} else {
return ir_builder()->CreateMul(lhs, rhs, name());
}
}
llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
if (scalar_type_->isFloatingPointTy()) {
return ir_builder()->CreateFAdd(lhs, rhs, name());
} else {
return ir_builder()->CreateAdd(lhs, rhs, name());
}
}
llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
llvm::Value* base_pointer, llvm::Value* offset_elements) {
if (base_pointer->getType() != scalar_pointer_type()) {
base_pointer = ir_builder()->CreateBitCast(base_pointer,
scalar_pointer_type(), name());
}
return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
name());
}
llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
if (pointer->getType() != vector_pointer_type()) {
pointer =
ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
}
return ir_builder()->CreateAlignedLoad(
pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}
llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
if (pointer->getType() != scalar_pointer_type()) {
pointer =
ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
return ir_builder()->CreateAlignedLoad(
pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}
void VectorSupportLibrary::StoreVector(llvm::Value* value,
llvm::Value* pointer) {
if (pointer->getType() != vector_pointer_type()) {
pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
}
ir_builder()->CreateAlignedStore(
value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}
void VectorSupportLibrary::StoreScalar(llvm::Value* value,
llvm::Value* pointer) {
if (pointer->getType() != scalar_pointer_type()) {
pointer =
ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
ir_builder()->CreateAlignedStore(
value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}
llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
if (pointer->getType() != scalar_pointer_type()) {
pointer =
ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
return ir_builder()->CreateVectorSplat(
vector_size(), ir_builder()->CreateLoad(pointer), name());
}
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle the upper half of the remaining lanes down
    // into the lower half and add it to the vector, so the partial sums
    // collapse toward lane 0.
for (unsigned j = 0; j < vector_size(); ++j) {
if (j < (i / 2)) {
mask[j] = ir_builder()->getInt32(i / 2 + j);
} else {
mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
}
}
llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
vector, llvm::UndefValue::get(vector_type()),
llvm::ConstantVector::get(mask), "");
vector = Add(vector, half_remaining_lanes);
}
return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
name());
}
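The loop above is a log2(vector_size) shuffle-and-add tree. A scalar C++ model of the same reduction, assuming the lane count is a nonzero power of two, makes the lane movement explicit:

#include <cstddef>
#include <vector>

// Model of AddReduce: each round folds the upper half of the still-live lanes
// onto the lower half, so lane 0 ends up holding the sum of every lane.
float AddReduceModel(std::vector<float> lanes) {
  for (std::size_t i = lanes.size(); i != 1; i >>= 1) {
    for (std::size_t j = 0; j < i / 2; ++j) {
      lanes[j] += lanes[i / 2 + j];
    }
  }
  return lanes[0];
}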
llvm::Value* VectorSupportLibrary::GetZeroVector() {
return llvm::Constant::getNullValue(vector_type());
}
llvm::Value* VectorSupportLibrary::GetZeroScalar() {
return llvm::Constant::getNullValue(scalar_type());
}
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
: ir_builder_(ir_builder) {
alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
}
llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); }
void LlvmVariable::Set(llvm::Value* new_value) {
ir_builder_->CreateStore(new_value, alloca_);
}
} // namespace xla
@@ -1,174 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
#include <string>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// A thin wrapper around llvm_util.h that makes code which emits vector math
// easier to read.
class VectorSupportLibrary {
public:
// This VectorSupportLibrary instance remembers `primitive_type` and
// `vector_size`, and these are implicitly used by the methods on this
// instance (i.e. LoadVector will load a vector of type <`vector_size` x
// `primitive_type`>).
VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
llvm::IRBuilder<>* ir_builder, std::string name);
llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
return Mul(ir_builder()->getInt64(lhs), rhs);
}
llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
return Add(ir_builder()->getInt64(lhs), rhs);
}
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
return Add(c, Mul(a, b));
}
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
llvm::Value* offset_elements);
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
int64 offset_elements) {
return ComputeOffsetPointer(base_pointer,
ir_builder()->getInt64(offset_elements));
}
llvm::Value* LoadVector(llvm::Value* pointer);
llvm::Value* LoadVector(llvm::Value* base_pointer,
llvm::Value* offset_elements) {
return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
}
llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
}
llvm::Value* LoadScalar(llvm::Value* pointer);
llvm::Value* LoadScalar(llvm::Value* base_pointer,
llvm::Value* offset_elements) {
return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
}
llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
}
void StoreVector(llvm::Value* value, llvm::Value* pointer);
void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
llvm::Value* offset_elements) {
StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
}
void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
int64 offset_elements) {
StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
}
void StoreScalar(llvm::Value* value, llvm::Value* pointer);
void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
llvm::Value* offset_elements) {
StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
}
void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
int64 offset_elements) {
    StoreScalar(value, base_pointer, ir_builder()->getInt64(offset_elements));
}
llvm::Value* LoadBroadcast(llvm::Value* pointer);
llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
llvm::Value* offset_elements) {
return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
}
llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
}
llvm::Value* AddReduce(llvm::Value* vector);
llvm::Value* GetZeroVector();
llvm::Value* GetZeroScalar();
llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
int64 vector_size() const { return vector_size_; }
llvm::Type* vector_type() const { return vector_type_; }
llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
llvm::Type* scalar_type() const { return scalar_type_; }
llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
const std::string& name() const { return name_; }
private:
int64 vector_size_;
PrimitiveType primitive_type_;
llvm::IRBuilder<>* ir_builder_;
llvm::Type* vector_type_;
llvm::Type* vector_pointer_type_;
llvm::Type* scalar_type_;
llvm::Type* scalar_pointer_type_;
std::string name_;
};
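The methods declared above are meant to be chained inside an emitted loop body. A minimal usage sketch, assuming an llvm::IRBuilder<> that is already positioned inside a function and a `base` pointer to at least 32 contiguous floats; the function name and the sizes are illustrative only, not part of the library.

#include <cstdint>

llvm::Value* EmitSumOf32Floats(llvm::IRBuilder<>* ir_builder,
                               llvm::Value* base) {
  xla::VectorSupportLibrary vsl(xla::F32, /*vector_size=*/8, ir_builder, "sum");
  llvm::Value* accumulator = vsl.GetZeroVector();
  for (int64_t i = 0; i < 4; ++i) {
    // Load lanes [8*i, 8*i + 8) and add them into the accumulator lane-wise.
    accumulator = vsl.Add(accumulator, vsl.LoadVector(base, i * 8));
  }
  // Collapse the eight partial sums into a single scalar.
  return vsl.AddReduce(accumulator);
}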
// This wraps an alloca-backed stack variable, which LLVM's SSA construction
// pass can later convert to an SSA value.
class LlvmVariable {
public:
LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);
llvm::Value* Get();
void Set(llvm::Value* new_value);
private:
llvm::AllocaInst* alloca_;
llvm::IRBuilder<>* ir_builder_;
};
class VectorVariable : public LlvmVariable {
public:
VectorVariable(VectorSupportLibrary* vector_support,
llvm::Value* initial_value)
: LlvmVariable(vector_support->vector_type(),
vector_support->ir_builder()) {
Set(initial_value);
}
};
class ScalarVariable : public LlvmVariable {
public:
ScalarVariable(VectorSupportLibrary* vector_support,
llvm::Value* initial_value)
: LlvmVariable(vector_support->scalar_type(),
vector_support->ir_builder()) {
Set(initial_value);
}
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
@@ -68,6 +68,26 @@ LocalService::LocalService(const ServiceOptions& options,
                           std::unique_ptr<Backend> execute_backend)
    : Service(options, std::move(execute_backend)) {}

+namespace {
+
+// Returns the space required to allocate a shape. If
+// allocate_space_for_deep_copy is true, the space includes all sub-buffers of
+// a tuple.
+int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
+                    TransferManager* transfer_manager) {
+  int64 size = 0;
+  // TODO(b/33492279) remove once no devices represent result tuples as
+  // contiguous buffers.
+  if (allocate_space_for_deep_copy) {
+    ShapeUtil::ForEachSubshape(
+        shape, [&size, transfer_manager](const Shape& subshape,
+                                         const ShapeIndex& /*index*/) {
+          size += transfer_manager->GetByteSizeRequirement(subshape);
+        });
+  }
+  return size;
+}
+
+}  // namespace
+
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
    const ComputationHandle& computation,
    const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
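RequiredSpace above simply folds GetByteSizeRequirement over every sub-shape. The accumulation pattern, shown on a hypothetical ToyShape type standing in for xla::Shape:

#include <cstdint>
#include <vector>

// Toy stand-in for a shape: either a leaf buffer with a byte size, or a tuple
// of sub-shapes.
struct ToyShape {
  int64_t leaf_bytes = 0;
  std::vector<ToyShape> sub_shapes;
};

// Sums the space needed for a deep copy: every sub-buffer contributes.
int64_t ToyRequiredSpace(const ToyShape& shape) {
  int64_t size = shape.leaf_bytes;
  for (const ToyShape& sub : shape.sub_shapes) {
    size += ToyRequiredSpace(sub);
  }
  return size;
}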
@@ -104,21 +104,6 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
  return Status::OK();
}

-Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
-  // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
-  // tuple element at {0} to its output.
-  return Status::OK();
-}
-
-Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
-  // Send creates new buffers for the top-level tuple and the context (tuple
-  // element at {1}). Tuple element at {0} is an alias of the Send operand, so
-  // we don't need to create a new Logical Buffer for that.
-  NewLogicalBuffer(send, /*index=*/{});
-  NewLogicalBuffer(send, /*index=*/{1});
-  return Status::OK();
-}
-
Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
  // A Tuple instruction only creates the top-level buffer.
  NewLogicalBuffer(tuple, /*index=*/{});
@@ -60,8 +60,6 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleCopy(HloInstruction* copy) override;
-  Status HandleRecvDone(HloInstruction* recv_done) override;
-  Status HandleSend(HloInstruction* send) override;
  Status HandleSelect(HloInstruction* select) override;

  // A map from the buffer ID to the logical buffer
@@ -272,6 +272,8 @@ class Service : public ServiceInterface {
  // Create a Hlo module config for the given program shape and arguments.
  // execution_options is optional; if not given a default is used.
+  // has_hybrid_result is used to initialize the same-named field in
+  // HloModuleConfig -- see that class for documentation.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -771,12 +770,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
-  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
-      lhs, tensorflow::strings::StrCat("lhs of binary operation ",
-                                       BinaryOperation_Name(operation))));
-  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
-      rhs, tensorflow::strings::StrCat("rhs of binary operation ",
-                                       BinaryOperation_Name(operation))));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation"));
  switch (operation) {
    case BINOP_DOT:
      return InferDotOpShape(lhs, rhs);
@@ -1948,10 +1943,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
-        "Reshape dimensions [%s] are not a permutation of the operand "
-        "dimensions (operand shape is %s).",
-        tensorflow::str_util::Join(dimensions, ",").c_str(),
-        ShapeUtil::HumanString(operand).c_str());
+        "Reshape dimensions not a permutation of the operand dimensions.");
  }
  return inferred_shape;
@@ -253,64 +253,6 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
  return Status::OK();
}
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// RecvDone aliases its input (Recv) tuple element {0} to its output.
PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
const PointsToSet& operand_points_to_set =
GetPointsToSet(recv_done->operand(0));
// Recursively copy the points to set of the operand tuple {0}.
points_to_set.ForEachMutableElement(
[this, &points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
ShapeIndex src_index({0});
for (auto element : index) {
src_index.push_back(element);
}
*buffers = operand_points_to_set.element(src_index);
for (auto& tuple_source :
operand_points_to_set.tuple_sources(src_index)) {
points_to_set.add_tuple_source(index, tuple_source);
}
});
return Status::OK();
}
Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
// Send creates a tuple of {aliased operand, U32 context}.
PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
// Creates the points to set for the tuple and its element at {1}.
auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
top_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
points_to_set.add_tuple_source({}, send);
auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
context_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
// Recursively copy the points to set of the operand to output tuple {0}.
const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
operand_points_to_set.ForEachElement(
[&points_to_set, &operand_points_to_set](
const ShapeIndex& src_index,
const PointsToSet::BufferList& points_to) {
ShapeIndex target_index({0});
for (auto element : src_index) {
target_index.push_back(element);
}
*points_to_set.mutable_element(target_index) = points_to;
for (HloInstruction* tuple :
operand_points_to_set.tuple_sources(src_index)) {
points_to_set.add_tuple_source(target_index, tuple);
}
});
return Status::OK();
}
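The index bookkeeping in HandleSend is the only subtle part: the operand's buffer at index {a, b, ...} reappears in the Send result at {0, a, b, ...}. A stand-alone sketch of that remapping, with a hypothetical ToyShapeIndex in place of xla::ShapeIndex:

#include <cstdint>
#include <vector>

using ToyShapeIndex = std::vector<int64_t>;

// Prefix the operand-relative index with 0, the tuple slot that aliases the
// Send operand.
ToyShapeIndex SendTargetIndex(const ToyShapeIndex& src_index) {
  ToyShapeIndex target_index{0};
  target_index.insert(target_index.end(), src_index.begin(), src_index.end());
  return target_index;
}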
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
  PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
@@ -251,8 +251,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleCopy(HloInstruction* copy) override;
-  Status HandleRecvDone(HloInstruction* recv_done) override;
-  Status HandleSend(HloInstruction* send) override;
  Status HandleSelect(HloInstruction* select) override;

  string ToString() const;
@@ -313,51 +313,6 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
                          {constant1, constant2, copy});
}
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
// Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(constant, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send).element({}), {send});
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
{send_done});
ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
}
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
}
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // Select from two different tuples. This should create an ambiguous points to
  // set containing the union of both sides.
@@ -2927,9 +2927,8 @@ void ComputationLowerer::Visit(
    case OpRequest::kRecvRequest: {
      const RecvRequest& recv_request = request.request().recv_request();
-      HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
+      hlo_instruction = add_instruction(HloInstruction::CreateRecv(
          request.output_shape(), recv_request.channel_handle().handle()));
-      hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
      break;
    }
@@ -3121,9 +3120,8 @@ void ComputationLowerer::Visit(
    case OpRequest::kSendRequest: {
      const SendRequest& send_request = request.request().send_request();
      HloInstruction* operand = lookup_instruction(send_request.operand());
-      HloInstruction* send = add_instruction(HloInstruction::CreateSend(
+      hlo_instruction = add_instruction(HloInstruction::CreateSend(
          operand, send_request.channel_handle().handle()));
-      hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
      break;
    }
@@ -58,9 +58,7 @@ static bool ContainsSendOrRecv(const HloComputation* comp) {
static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kSend ||
-      instr->opcode() == HloOpcode::kSendDone ||
-      instr->opcode() == HloOpcode::kRecv ||
-      instr->opcode() == HloOpcode::kRecvDone) {
+      instr->opcode() == HloOpcode::kRecv) {
    return true;
  }
  for (const auto& subcomp : instr->called_computations()) {
@@ -144,11 +144,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
  auto* while_op = computation->root_instruction();
  ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
  auto* while_body = while_op->while_body();
-  auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
+  while_body->AddInstruction(HloInstruction::CreateSend(
      while_body->AddInstruction(
          HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
      /*channel_id=*/0));
-  while_body->AddInstruction(HloInstruction::CreateSendDone(send));
  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
@@ -157,10 +156,9 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
  auto* while_op = computation->root_instruction();
  ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
  auto* while_body = while_op->while_body();
-  auto* recv = while_body->AddInstruction(
+  while_body->AddInstruction(
      HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
                                 /*channel_id=*/0));
-  while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
@@ -116,7 +116,6 @@ class ShapeTree {
  ShapeTree(const Shape* shape, const T& init_value);

  ShapeTree(const ShapeTree& other) { *this = other; }
-  ShapeTree(ShapeTree&&) = default;

  ShapeTree& operator=(const ShapeTree& other) {
    root_ = other.root_;
@@ -133,8 +132,6 @@ class ShapeTree {
    return *this;
  }

-  ShapeTree& operator=(ShapeTree&& other) = default;
-
  // Returns the data element associated with the array in the shape at the
  // given index (see ShapeUtil::GetSubshape for how indexes are defined).
  const T& element(const ShapeIndex& index) const;
@@ -263,7 +263,6 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
    case S32:
    case S64:
    case F16:
-    case BF16:
    case F32:
    case F64:
      return true;
Some files were not shown because too many files have changed in this diff.