Merge branch 'master' into tflite

commit 31c345a9fb

.gitignore (vendored): 8 changes
@@ -24,10 +24,10 @@ Pods
 Podfile.lock
 *.pbxproj
 *.xcworkspacedata
-/tensorflow/contrib/lite/tools/make/downloads/**
-/tensorflow/contrib/lite/gen/**
-/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
-/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
+/tensorflow/lite/tools/make/downloads/**
+/tensorflow/lite/gen/**
+/tensorflow/lite/examples/ios/simple/data/*.txt
+/tensorflow/lite/examples/ios/simple/data/*.tflite
 xcuserdata/**
 /api_init_files_list.txt
 /estimator_api_init_files_list.txt
BUILD: 2 changes

@@ -2,5 +2,7 @@ exports_files(
     [
         "LICENSE",
         "ACKNOWLEDGEMENTS",
+        "configure",
+        "configure.py",
     ],
 )
@@ -258,8 +258,8 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
 * Update `tf.keras` to the Keras 2.1.6 API.
 * Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
 * Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
-* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite)
-  for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md)
+* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite)
+  for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md)
   has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
   included in the standard `pip` installation.
 * Improved data-loading and text processing with:
@@ -562,7 +562,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
 ## Major Features And Improvements
 * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
   preview version is now available.
-* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
+* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/lite)
   dev preview is now available.
 * CUDA 9.0 and cuDNN 7 support.
 * Accelerated Linear Algebra (XLA):
@@ -909,7 +909,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/
 * Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss.
 * `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less or equal to 1 unit of weight.
 * Adds time series models to contrib. See contrib/timeseries/README.md for details.
-* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs
+* Adds FULLY_CONNECTED Op to tensorflow/lite/schema.fbs
 
 ## Known Issues
 * Tensorflow_gpu compilation fails with Bazel 0.5.3.
configure.py: 25 changes

@@ -1418,11 +1418,16 @@ def set_mpi_home(environ_cp):
   def valid_mpi_path(mpi_home):
     exists = (
         os.path.exists(os.path.join(mpi_home, 'include')) and
-        os.path.exists(os.path.join(mpi_home, 'lib')))
+        (os.path.exists(os.path.join(mpi_home, 'lib')) or
+         os.path.exists(os.path.join(mpi_home, 'lib64')) or
+         os.path.exists(os.path.join(mpi_home, 'lib32'))))
     if not exists:
-      print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
-            (os.path.join(mpi_home, 'include'),
-             os.path.exists(os.path.join(mpi_home, 'lib'))))
+      print(
+          'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
+          % (os.path.join(mpi_home, 'include'),
+             os.path.exists(os.path.join(mpi_home, 'lib')),
+             os.path.exists(os.path.join(mpi_home, 'lib64')),
+             os.path.exists(os.path.join(mpi_home, 'lib32'))))
     return exists
 
   _ = prompt_loop_or_load_from_env(
@@ -1463,8 +1468,17 @@ def set_other_mpi_vars(environ_cp):
   if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
     symlink_force(
         os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
+  elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
+    symlink_force(
+        os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
+  elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
+    symlink_force(
+        os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
+
   else:
-    raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
+    raise ValueError(
+        'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
+        mpi_home, mpi_home, mpi_home)
 
 
 def set_system_libs_flag(environ_cp):
@@ -1681,4 +1695,3 @@ def main():
 
 if __name__ == '__main__':
   main()
-
@@ -55,4 +55,10 @@ except NameError:
   # does not have 'python', 'core' directories. Then, it will be copied
   # to tensorflow/ which does have these two directories.
   pass
+# Similarly for compiler. Do it separately to make sure we do this even if the
+# others don't exist.
+try:
+  del compiler
+except NameError:
+  pass
 # pylint: enable=undefined-variable
@@ -63,4 +63,10 @@ except NameError:
   # does not have 'python', 'core' directories. Then, it will be copied
   # to tensorflow/ which does have these two directories.
   pass
+# Similarly for compiler. Do it separately to make sure we do this even if the
+# others don't exist.
+try:
+  del compiler
+except NameError:
+  pass
 # pylint: enable=undefined-variable
@@ -199,6 +199,7 @@ tf_cuda_cc_test(
     size = "small",
     srcs = ["c_api_test.cc"],
     data = [
+        ":test_op1.so",
        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
     ],
     kernels = [":test_op_kernel"],
@@ -283,8 +284,8 @@ tf_cc_test(
 )
 
 tf_custom_op_library(
-    name = "test_op.so",
-    srcs = ["test_op.cc"],
+    name = "test_op1.so",
+    srcs = ["test_op1.cc"],
 )
 
 tf_kernel_library(
@@ -8775,3 +8775,28 @@ void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
       tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
       /* def = */ nullptr, /* kernel_class_name = */ nullptr);
 }
+
+const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
+                                           TF_Status* status) {
+  const tensorflow::OpDef* op_def = nullptr;
+  status->status =
+      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
+  if (!status->status.ok()) return nullptr;
+
+  if (input_index >= op_def->input_arg_size() || input_index < 0) {
+    status->status = tensorflow::errors::InvalidArgument(
+        input_index, " out of range for ", op_name);
+    return nullptr;
+  }
+
+  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
+
+  if (input_arg.number_attr().empty()) {
+    status->status = tensorflow::errors::NotFound(
+        op_name, " does not have number_attr() defined.");
+    return nullptr;
+  }
+
+  // The returned string is owned by OpRegistry, so liveness is not a concern.
+  return input_arg.number_attr().c_str();
+}
@@ -202,6 +202,13 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder,
 TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
     TF_AttrBuilder* builder, const char* device_type, TF_Status* status);
 
+// For argument number input_index, fetch the corresponding number_attr that
+// needs to be updated with the argument length of the input list.
+// Returns nullptr if there is any problem like op_name is not found, or the
+// argument does not support this attribute type.
+TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
+    const char* op_name, int input_index, TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
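The comment above documents the new C API entry point: given an op name and an input index, it returns the name of the number_attr that carries the length of that list-typed input. A minimal usage sketch, not taken from this commit; the choice of the AddN op and the exact header holding the new declaration are assumptions for illustration:

#include <stdio.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"  // assumed location of the new declaration

int main() {
  TF_Status* status = TF_NewStatus();
  // AddN is assumed to declare its input list as "inputs: N * T", so the
  // attribute governing the length of input index 0 should be "N".
  const char* attr_name = TF_GetNumberAttrForOpListInput("AddN", 0, status);
  if (TF_GetCode(status) == TF_OK && attr_name != nullptr) {
    printf("list-length attr: %s\n", attr_name);
  } else {
    printf("lookup failed: %s\n", TF_Message(status));
  }
  TF_DeleteStatus(status);
  return 0;
}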
@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
// tf_cuda_cc_test() bazel rule and remove the next line.
|
||||
if (!GPUDeviceName().empty()) return;
|
||||
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
TF_Code code = TF_GetCode(status);
|
||||
string status_msg(TF_Message(status));
|
||||
TF_DeleteStatus(status);
|
||||
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
|
||||
TF_Code code = TF_GetCode(status);
|
||||
string status_msg(TF_Message(status));
|
||||
TF_DeleteStatus(status);
|
||||
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||
|
||||
// Test op list.
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
tensorflow::OpList op_list;
|
||||
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
|
||||
ASSERT_EQ(op_list.op_size(), 1);
|
||||
EXPECT_EQ("TestCApi1", op_list.op(0).name());
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
||||
tensorflow::OpList op_list;
|
||||
@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||
EXPECT_TRUE(found);
|
||||
TF_DeleteBuffer(op_list_buffer);
|
||||
}
|
||||
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
{
|
||||
// Test op list.
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
tensorflow::OpList op_list;
|
||||
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
|
||||
ASSERT_EQ(op_list.op_size(), 1);
|
||||
EXPECT_EQ("TestCApi", op_list.op(0).name());
|
||||
}
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
@ -2349,14 +2347,8 @@ TEST(TestApiDef, TestCreateApiDef) {
|
||||
// tf_cuda_cc_test() bazel rule and remove the next line.
|
||||
if (!GPUDeviceName().empty()) return;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TF_Buffer* op_list_buf = TF_GetAllOpList();
|
||||
status = TF_NewStatus();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
@ -2376,7 +2368,6 @@ TEST(TestApiDef, TestCreateApiDef) {
|
||||
TF_DeleteBuffer(api_def_buf);
|
||||
TF_DeleteApiDefMap(api_def_map);
|
||||
TF_DeleteBuffer(op_list_buf);
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
|
||||
@ -2384,14 +2375,8 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
|
||||
// tf_cuda_cc_test() bazel rule and remove the next line.
|
||||
if (!GPUDeviceName().empty()) return;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TF_Buffer* op_list_buf = TF_GetAllOpList();
|
||||
status = TF_NewStatus();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
@ -2422,7 +2407,6 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
|
||||
TF_DeleteBuffer(api_def_buf);
|
||||
TF_DeleteApiDefMap(api_def_map);
|
||||
TF_DeleteBuffer(op_list_buf);
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
class DummyKernel : public tensorflow::OpKernel {
|
||||
|
@@ -69,7 +69,7 @@ tf_cuda_library(
     name = "c_api_internal",
     hdrs = ["c_api_internal.h"],
     visibility = [
-        "//learning/deepmind/courier:__pkg__",
+        "//learning/deepmind/courier:__subpackages__",
        "//tensorflow:internal",
     ],
     deps = [
@@ -79,10 +79,6 @@ struct TFE_TensorHandle {
                    tensorflow::Device* op_device)
       : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}
 
-  TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
-                   tensorflow::EagerContext* ctx)
-      : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
-
   TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
 
   tensorflow::TensorHandle* handle;
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,17 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
-#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
 
-#include "absl/strings/string_view.h"
+namespace tensorflow {
 
-namespace stream_executor {
-namespace port {
+REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc");
 
-using StringPiece = absl::string_view;
-
-}  // namespace port
-}  // namespace stream_executor
-
-#endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
+}  // namespace tensorflow
@ -170,6 +170,7 @@ cc_library_with_android_deps(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc(
|
||||
":array_ops",
|
||||
":const_op",
|
||||
":math_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -190,11 +190,13 @@ cc_library(
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:stack",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
@ -240,6 +242,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
@ -499,6 +502,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) {
|
||||
return errors::Internal("Could not find compilation device ",
|
||||
device_type.type());
|
||||
}
|
||||
*result = registration->requires_compilation;
|
||||
*result = registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -525,7 +525,6 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
op->GetOperands().begin(),
|
||||
op->GetOperands().end());
|
||||
} else {
|
||||
std::vector<Predicate*> sub_ops_intersection;
|
||||
common_inner_operands.clear();
|
||||
absl::c_copy_if(op->GetOperands(),
|
||||
std::back_inserter(common_inner_operands),
|
||||
|
@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
|
||||
Output loop_cond =
|
||||
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
|
||||
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
|
||||
latch.output_false);
|
||||
Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
|
||||
latch.output_true, increment_by);
|
||||
Output next_iteration =
|
||||
@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue(
|
||||
value, frame_name);
|
||||
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
|
||||
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
|
||||
latch.output_false);
|
||||
Output next_iteration = ops::NextIteration(
|
||||
root.WithOpName(prefix + "/next_iteration"), latch.output_true);
|
||||
CHECK(root.graph()
|
||||
|
@ -1122,8 +1122,11 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
|
||||
fdef);
|
||||
}
|
||||
|
||||
if (!reuse_existing_functions || library->Find(name) == nullptr) {
|
||||
const FunctionDef* original_fdef = library->Find(name);
|
||||
if (!reuse_existing_functions || original_fdef == nullptr) {
|
||||
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
|
||||
} else if (!FunctionDefsEqual(*original_fdef, fdef)) {
|
||||
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -51,6 +51,12 @@ xla::StatusOr<Node*> AddHostComputeKeyPlaceholder(
|
||||
return n;
|
||||
}
|
||||
|
||||
// Returns if the node is a XLA computation key placeholder.
|
||||
bool IsKeyPlaceholderNode(const Node& n) {
|
||||
return n.type_string() == "Placeholder" &&
|
||||
absl::EndsWith(n.name(), "_key_placeholder");
|
||||
}
|
||||
|
||||
// Returns nodes with given type.
|
||||
std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
|
||||
std::vector<Node*> result;
|
||||
@ -107,6 +113,8 @@ xla::StatusOr<Node*> BuildRecvAtHostNode(
|
||||
xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
|
||||
Graph* g, const string& oc_cluster_name,
|
||||
std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
|
||||
// TODO(b/77601805): use out nodes for source node, instead of traversing all
|
||||
// nodes.
|
||||
std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
|
||||
TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -218,6 +226,8 @@ xla::StatusOr<Node*> BuildSendFromHostNode(
|
||||
xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
|
||||
Graph* g, const string& oc_cluster_name,
|
||||
std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
|
||||
// TODO(b/77601805): use in nodes for sink node, instead of traversing all
|
||||
// nodes.
|
||||
std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
|
||||
TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -258,7 +268,7 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
const PartialTensorShape shape = shapes[e->dst_input()];
|
||||
const PartialTensorShape shape = shapes[e->src_output()];
|
||||
if (!shape.IsFullyDefined()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
@ -411,8 +421,7 @@ Status ConstructHostGraph(
|
||||
if (node_map.find(n) != node_map.end()) {
|
||||
// Already copied this node.
|
||||
copy = node_map.at(n);
|
||||
} else if (n->type_string() == "Placeholder" &&
|
||||
absl::EndsWith(n->name(), "_key_placeholder")) {
|
||||
} else if (IsKeyPlaceholderNode(*n)) {
|
||||
// Change a).
|
||||
copy = key_placeholder;
|
||||
node_map[n] = copy;
|
||||
@ -691,8 +700,7 @@ Status RewriteOutsideCompilationSubgraphFn::operator()(
|
||||
|
||||
// Step 4: add XLA cluster and outside compilation attr.
|
||||
for (Node* n : (*graph)->nodes()) {
|
||||
if (n->type_string() == "Placeholder" &&
|
||||
absl::EndsWith(n->name(), "_key_placeholder")) {
|
||||
if (IsKeyPlaceholderNode(*n)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -221,8 +221,8 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
|
||||
.WithOpName("static_shaped_slice"),
|
||||
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
|
||||
.node();
|
||||
std::vector<int> compile_time_const_inputs;
|
||||
compile_time_const_inputs.push_back(2);
|
||||
std::vector<string> compile_time_const_inputs;
|
||||
compile_time_const_inputs.push_back("size");
|
||||
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
|
||||
compile_time_const_inputs);
|
||||
return status;
|
||||
@ -314,15 +314,18 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
||||
|
||||
Status IncreaseDynamismForAutoJitPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
if (flags->tf_xla_clustering_debug) {
|
||||
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
|
||||
**options.graph, options.flib_def);
|
||||
}
|
||||
|
||||
bool changed;
|
||||
TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
|
||||
if (changed) {
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
if (flags->tf_xla_clustering_debug) {
|
||||
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
|
||||
**options.graph, options.flib_def);
|
||||
}
|
||||
if (changed && flags->tf_xla_clustering_debug) {
|
||||
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
|
||||
**options.graph, options.flib_def);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -129,8 +129,8 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) {
|
||||
Op("ConcatV2"), AssignedDevice(kHostName),
|
||||
Inputs(m_slice_size_0, Const(static_cast<int64>(500)), Const(zero_32))));
|
||||
|
||||
std::vector<int> compile_time_constant_inputs;
|
||||
compile_time_constant_inputs.push_back(2);
|
||||
std::vector<string> compile_time_constant_inputs;
|
||||
compile_time_constant_inputs.push_back("size");
|
||||
auto m_dynamic_slice = NodeWith(
|
||||
Op("Slice"), AssignedDevice(kDeviceName),
|
||||
Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),
|
||||
|
@ -39,12 +39,22 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
|
||||
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
|
||||
// in error case, it returns RET instead of void.
|
||||
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
|
||||
do { \
|
||||
::tensorflow::Status _s(__VA_ARGS__); \
|
||||
if (!TF_PREDICT_TRUE(_s.ok())) { \
|
||||
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
|
||||
return RET; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status PlatformInfoFromContext(OpKernelConstruction* ctx,
|
||||
XlaPlatformInfo* result) {
|
||||
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx,
|
||||
}
|
||||
|
||||
if (!device_allocator) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform* const platform,
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id));
|
||||
xla::StatusOr<se::Platform*> maybe_platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||
|
||||
xla_allocator = absl::make_unique<XlaAllocator>(
|
||||
platform, ctx->device()->GetAllocator({}));
|
||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||
}
|
||||
|
||||
*result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
std::move(xla_allocator), device_allocator);
|
||||
|
||||
return Status::OK();
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
std::move(xla_allocator), device_allocator);
|
||||
}
|
||||
|
||||
// A closure describing how to run a compiled version of a TensorFlow function.
|
||||
@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
: OpKernel(ctx),
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
}
|
||||
function_(function),
|
||||
platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
|
||||
static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
@ -333,18 +342,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
|
||||
// in error case, it returns RET instead of void.
|
||||
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
|
||||
do { \
|
||||
::tensorflow::Status _s(__VA_ARGS__); \
|
||||
if (!TF_PREDICT_TRUE(_s.ok())) { \
|
||||
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
|
||||
return RET; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Helper static functions to construct parameters for
|
||||
// XlaLocalLaunchBase constructor from OpKernelConstruction.
|
||||
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
|
||||
@ -381,7 +378,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
|
||||
return *func;
|
||||
}
|
||||
|
||||
#undef OP_REQUIRES_OK_RETURN
|
||||
bool MustCompileAttr(OpKernelConstruction* ctx) {
|
||||
bool must_compile;
|
||||
OP_REQUIRES_OK_RETURN(ctx, false,
|
||||
ctx->GetAttr("must_compile", &must_compile));
|
||||
return must_compile;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||
@ -396,10 +398,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx),
|
||||
constants_(ConstantsVector(ctx)),
|
||||
resources_(ResourcesVector(ctx)),
|
||||
function_(FunctionAttr(ctx)) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_));
|
||||
}
|
||||
function_(FunctionAttr(ctx)),
|
||||
platform_info_(PlatformInfoFromContext(ctx)),
|
||||
must_compile_(MustCompileAttr(ctx)) {}
|
||||
|
||||
void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaCompileOp " << def().name()
|
||||
@ -409,13 +410,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
|
||||
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) {
|
||||
bool cannot_compile_cluster;
|
||||
{
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster = cannot_compile_cluster_;
|
||||
}
|
||||
|
||||
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
|
||||
cannot_compile_cluster) {
|
||||
executable = nullptr;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_,
|
||||
constants_, /*lazy=*/!must_compile_, &client,
|
||||
&variables, &kernel, &executable));
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_, constants_,
|
||||
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
||||
if (status.code() == error::UNIMPLEMENTED) {
|
||||
LOG(WARNING) << "Compilation failed:" << status.ToString()
|
||||
<< ". Falling back to TF function call.";
|
||||
executable = nullptr;
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
AllocatorAttributes host_alloc_attrs;
|
||||
@ -452,9 +470,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
ctx->set_output(1, compilation_successful);
|
||||
}
|
||||
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
|
||||
}
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
|
||||
|
||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaRunOp " << def().name();
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
@ -33,6 +35,7 @@ namespace tensorflow {
|
||||
class XlaPlatformInfo {
|
||||
public:
|
||||
XlaPlatformInfo() : device_type_("") {}
|
||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
||||
se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel {
|
||||
|
||||
protected:
|
||||
// Indexes of compile-time constant inputs
|
||||
std::vector<int> constants_;
|
||||
const std::vector<int> constants_;
|
||||
// Indexes of resource inputs
|
||||
std::vector<int> resources_;
|
||||
const std::vector<int> resources_;
|
||||
|
||||
NameAttrList function_;
|
||||
XlaPlatformInfo platform_info_;
|
||||
const NameAttrList function_;
|
||||
const XlaPlatformInfo platform_info_;
|
||||
};
|
||||
|
||||
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
|
||||
@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel {
|
||||
|
||||
private:
|
||||
// Indexes of compile-time constant inputs
|
||||
std::vector<int> constants_;
|
||||
const std::vector<int> constants_;
|
||||
// Indexes of resource inputs
|
||||
std::vector<int> resources_;
|
||||
const std::vector<int> resources_;
|
||||
|
||||
NameAttrList function_;
|
||||
const NameAttrList function_;
|
||||
|
||||
XlaPlatformInfo platform_info_;
|
||||
|
||||
bool must_compile_;
|
||||
const bool must_compile_;
|
||||
|
||||
// cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
|
||||
// error when compiling the cluster this _XlaCompile is supposed to compile.
|
||||
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
|
||||
// on any future calls to _XlaCompile.
|
||||
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
|
||||
|
||||
mutex cannot_compile_cluster_mu_;
|
||||
};
|
||||
|
||||
class XlaRunOp : public OpKernel {
|
||||
@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
XlaPlatformInfo platform_info_;
|
||||
const XlaPlatformInfo platform_info_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -49,6 +49,25 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Aggregates information about what kinds of ops are allowed.
|
||||
struct OperationFilter {
|
||||
// Whether resource variable ops are allowed. We do not allow resource
|
||||
// variable ops in called functions (either as direct TF calls or as higher
|
||||
// order control flow ops) because we do not yet model their memory effects in
|
||||
// jit/resource_variable_safety_analysis.
|
||||
bool allow_resource_ops;
|
||||
|
||||
// Whether stateful RNG ops are allowed. XLA's RNG does not have the same
|
||||
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
|
||||
// auto-clustering stateful RNG ops.
|
||||
bool allow_stateful_rng_ops;
|
||||
};
|
||||
|
||||
bool IsStatefulRandomOp(absl::string_view op_name) {
|
||||
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
|
||||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
|
||||
op_name == "TruncatedNormal";
|
||||
}
|
||||
|
||||
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
|
||||
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
|
||||
@ -101,7 +120,7 @@ const int kMaxRecursionDepth = 10;
|
||||
|
||||
bool IsCompilableCall(const NodeDef& call_def,
|
||||
const DeviceType& jit_device_type,
|
||||
bool allow_resource_ops, int depth,
|
||||
const OperationFilter& op_filter, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime);
|
||||
|
||||
// Tests whether 'while_node' is a completely compilable loop.
|
||||
@ -109,7 +128,7 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
// while loop to be compilable.
|
||||
bool IsCompilableWhile(const Node& while_node,
|
||||
const DeviceType& jit_device_type,
|
||||
bool allow_resource_ops, int depth,
|
||||
const OperationFilter& op_filter, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime) {
|
||||
const NameAttrList* name_attr;
|
||||
NodeDef call;
|
||||
@ -124,7 +143,7 @@ bool IsCompilableWhile(const Node& while_node,
|
||||
call.set_name("while_cond");
|
||||
call.set_op(cond_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
|
||||
if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": can't compile loop condition: " << cond_func;
|
||||
@ -140,7 +159,7 @@ bool IsCompilableWhile(const Node& while_node,
|
||||
call.set_name("while_body");
|
||||
call.set_op(body_func);
|
||||
*call.mutable_attr() = name_attr->attr();
|
||||
if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
|
||||
if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Rejecting While " << while_node.name()
|
||||
<< ": can't compile loop body: " << body_func;
|
||||
@ -154,7 +173,7 @@ bool IsCompilableWhile(const Node& while_node,
|
||||
// compilable.
|
||||
bool IsCompilableCall(const NodeDef& call_def,
|
||||
const DeviceType& jit_device_type,
|
||||
bool allow_resource_ops, int depth,
|
||||
const OperationFilter& op_filter, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime) {
|
||||
if (depth > kMaxRecursionDepth) {
|
||||
VLOG(2) << "Rejecting " << call_def.op()
|
||||
@ -195,16 +214,20 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
continue;
|
||||
if (node->type_string() == "While") {
|
||||
// Handle functional While loop.
|
||||
return IsCompilableWhile(*node, jit_device_type, allow_resource_ops,
|
||||
depth + 1, lib_runtime);
|
||||
return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1,
|
||||
lib_runtime);
|
||||
}
|
||||
if (!allow_resource_ops &&
|
||||
if (!op_filter.allow_resource_ops &&
|
||||
(HasResourceInput(*node) || HasResourceOutput(*node))) {
|
||||
return false;
|
||||
}
|
||||
if (!op_filter.allow_stateful_rng_ops &&
|
||||
IsStatefulRandomOp(node->type_string())) {
|
||||
return false;
|
||||
}
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, allow_resource_ops,
|
||||
depth + 1, lib_runtime)) {
|
||||
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
|
||||
<< node->name() << ": " << node->def().ShortDebugString();
|
||||
return false;
|
||||
@ -426,14 +449,28 @@ Status FindCompilationCandidates(
|
||||
CHECK(
|
||||
XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration));
|
||||
DeviceType jit_device_type(registration->compilation_device_name);
|
||||
|
||||
OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops = registration->compile_resource_ops;
|
||||
op_filter.allow_stateful_rng_ops =
|
||||
(registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways);
|
||||
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type,
|
||||
registration->compile_resource_ops, 0, lib_runtime)) {
|
||||
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
|
||||
<< node->type_string();
|
||||
continue;
|
||||
}
|
||||
if (!registration->compile_resource_ops &&
|
||||
|
||||
if (!op_filter.allow_stateful_rng_ops &&
|
||||
IsStatefulRandomOp(node->type_string())) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!op_filter.allow_resource_ops &&
|
||||
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
|
||||
// We don't have a way of returning values of type DT_RESOURCE from XLA
|
||||
// computations so we avoid auto-clustering nodes producing DT_RESOURCE.
|
||||
@ -444,6 +481,7 @@ Status FindCompilationCandidates(
|
||||
<< node->type_string();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (compile_time_const_nodes[node->id()]) {
|
||||
const OpDef* op_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -501,9 +539,7 @@ Status FindCompilationCandidates(
|
||||
// registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
|
||||
// for CPU/GPU.
|
||||
if (node->type_string() == "While" &&
|
||||
!IsCompilableWhile(*node, jit_device_type,
|
||||
registration->compile_resource_ops, 0,
|
||||
lib_runtime)) {
|
||||
!IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) {
|
||||
continue;
|
||||
}
|
||||
// _Arg nodes in a top-level function represent feeds.
|
||||
@ -563,10 +599,12 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
|
||||
®istration));
|
||||
DeviceType jit_device_type(registration->compilation_device_name);
|
||||
|
||||
// We can always *compile* resource operations, even if we are sometimes
|
||||
// unable to auto-cluster them.
|
||||
const bool compile_resource_ops = true;
|
||||
return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr);
|
||||
// We can always *compile* resource operations and stateful RNGs, even if we
|
||||
// are sometimes unable to auto-cluster them.
|
||||
OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
|
||||
}
|
||||
|
||||
Status MarkForCompilationPass::Run(
|
||||
@ -577,10 +615,8 @@ Status MarkForCompilationPass::Run(
|
||||
GetGlobalJitLevel(options);
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||
bool fusion_only = flags->tf_xla_fusion_only;
|
||||
|
||||
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
|
||||
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
|
||||
VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
@ -599,9 +635,6 @@ Status MarkForCompilationPass::Run(
|
||||
return false;
|
||||
}
|
||||
|
||||
// If this device requires a JIT, we must say yes.
|
||||
if (registration->requires_compilation) return true;
|
||||
|
||||
// If there is a _XlaCompile annotation, use its value.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
|
||||
@ -638,18 +671,21 @@ Status MarkForCompilationPass::Run(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Otherwise use the value of global_jit_level.
|
||||
// Ignore enable_jit_by_default if global jit compilation for CPU
|
||||
// is explicitly requested via tf_xla_cpu_global_jit flag
|
||||
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
|
||||
// Otherwise use the value of global_jit_level and the device's
|
||||
// autoclustering policy.
|
||||
bool should_compile =
|
||||
(ignore_registration || registration->enable_jit_by_default) &&
|
||||
global_jit_level != OptimizerOptions::OFF;
|
||||
registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways ||
|
||||
(registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
|
||||
global_jit_level != OptimizerOptions::OFF);
|
||||
if (!should_compile) {
|
||||
if (global_jit_level == OptimizerOptions::OFF) {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
|
||||
} else {
|
||||
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
|
||||
VLOG(2)
|
||||
<< "Rejecting " << node->name()
|
||||
<< ": autoclustering for device only when requested explicitly.";
|
||||
}
|
||||
}
|
||||
return should_compile;
|
||||
@ -1037,12 +1073,10 @@ Status MarkForCompilationPass::RunImpl(
|
||||
XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration);
|
||||
|
||||
// Compile if this is a cluster of >= min_cluster_size compilable operators.
|
||||
// Also, always compile if the operator is placed on a device that requires
|
||||
// compilation, or if it contains at least one op that is marked for
|
||||
// Also, always compile if it contains at least one op that is marked for
|
||||
// compilation that is not an Identity op.
|
||||
if (effective_cluster_sizes[cluster] >= min_cluster_size ||
|
||||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation) ||
|
||||
registration->requires_compilation) {
|
||||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) {
|
||||
string& name = cluster_names[cluster];
|
||||
|
||||
if (name.empty()) {
|
||||
|
@ -923,9 +923,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["test/shape_rng"], "");
|
||||
EXPECT_NE(clusters["test/reshape"], "");
|
||||
EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
|
||||
EXPECT_EQ(clusters["test/shape_rng"], "");
|
||||
EXPECT_EQ(clusters["test/reshape"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
|
||||
@ -1061,5 +1060,48 @@ TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
|
||||
// Improved Heuristics should prevent this probably.
|
||||
EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
|
||||
absl::string_view xla_cpu_device =
|
||||
"/job:worker/replica:0/task:0/device:XLA_CPU:0";
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
|
||||
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
|
||||
Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
|
||||
Output c = ops::Add(root.WithOpName("test/c"), a, b);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
|
||||
n->set_assigned_device_name(string(xla_cpu_device));
|
||||
}
|
||||
}
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["test/a"], "");
|
||||
EXPECT_NE(clusters["test/b"], "");
|
||||
EXPECT_NE(clusters["test/c"], "");
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
|
||||
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
|
||||
Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
|
||||
Output c = ops::Add(root.WithOpName("test/c"), a, b);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(clusters["test/a"], "");
|
||||
EXPECT_EQ(clusters["test/b"], "");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -485,6 +485,16 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
|
||||
return {int_list_attr.first, attr_value};
|
||||
}
|
||||
|
||||
std::pair<string, AttrValue> impl::AttrLiteralHelper(
|
||||
const std::pair<string, absl::Span<const string>>& string_list_attr) {
|
||||
AttrValue attr_value;
|
||||
AttrValue::ListValue* list = attr_value.mutable_list();
|
||||
for (string s : string_list_attr.second) {
|
||||
list->add_s(s);
|
||||
}
|
||||
return {string_list_attr.first, attr_value};
|
||||
}
|
||||
|
||||
impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
|
||||
impl::NodeMatcherProperties props;
|
||||
props.set_attr(std::move(attr));
|
||||
|
@ -170,6 +170,9 @@ std::pair<string, AttrValue> AttrLiteralHelper(
|
||||
|
||||
std::pair<string, AttrValue> AttrLiteralHelper(
|
||||
const std::pair<string, absl::Span<const int>>& int_list_attr);
|
||||
|
||||
std::pair<string, AttrValue> AttrLiteralHelper(
|
||||
const std::pair<string, absl::Span<const string>>& string_list_attr);
|
||||
} // namespace impl
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -210,7 +210,8 @@ bool IsIntraClusterEdge(const Edge& edge) {
|
||||
bool IsMustCompileDevice(const DeviceType& device_type) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
return registration->requires_compilation;
|
||||
return registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -173,9 +173,7 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
XlaCompiler::Options options;
|
||||
options.device_type = metadata.jit_device_type();
|
||||
options.client = metadata.client();
|
||||
auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
|
||||
OpRegistry::Global(), FunctionDefLibrary{});
|
||||
options.flib_def = flib_def.get();
|
||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||
options.shape_representation_fn = metadata.shape_representation_fn();
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
|
@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
|
||||
registration.requires_compilation = !compile_on_demand;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
compile_on_demand
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
@ -60,7 +62,6 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
options.device_name = DEVICE_XLA_CPU;
|
||||
options.device_ordinal = 0;
|
||||
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
|
||||
options.transfer_as_literal = false;
|
||||
options.use_multiple_streams = false;
|
||||
auto device = absl::make_unique<XlaDevice>(session_options, options);
|
||||
devices->push_back(device.release());
|
||||
|
@ -201,12 +201,18 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
|
||||
jit_device_name_(options.compilation_device_name),
|
||||
platform_(options.platform),
|
||||
use_multiple_streams_(options.use_multiple_streams),
|
||||
transfer_as_literal_(options.transfer_as_literal),
|
||||
shape_representation_fn_(options.shape_representation_fn) {
|
||||
VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
|
||||
<< this;
|
||||
thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
|
||||
/*num_threads=*/1));
|
||||
|
||||
// We have multiple device to device streams to allow for some concurrency
|
||||
// between transfers. The particular value of '4' is chosen fairly
|
||||
// arbitrarily. It may be necessary to make this tunable via
|
||||
// XlaDevice::Options.
|
||||
static constexpr int kNumDeviceToDeviceStreams = 4;
|
||||
device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
|
||||
}
|
||||
|
||||
XlaDevice::~XlaDevice() {
|
||||
@ -274,8 +280,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
|
||||
&need_new_device_context));
|
||||
|
||||
std::shared_ptr<se::Stream> host_to_device_stream = stream_;
|
||||
std::shared_ptr<se::Stream> device_to_host_stream = stream_;
|
||||
std::shared_ptr<se::Stream> host_to_device_stream;
|
||||
std::shared_ptr<se::Stream> device_to_host_stream;
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
|
||||
if (use_multiple_streams_) {
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
|
||||
&host_to_device_stream_,
|
||||
@ -283,8 +290,18 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
|
||||
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
|
||||
&device_to_host_stream_,
|
||||
&need_new_device_context));
|
||||
for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
|
||||
&need_new_device_context));
|
||||
}
|
||||
host_to_device_stream = host_to_device_stream_;
|
||||
device_to_host_stream = device_to_host_stream_;
|
||||
device_to_device_streams = device_to_device_streams_;
|
||||
} else {
|
||||
host_to_device_stream = stream_;
|
||||
device_to_host_stream = stream_;
|
||||
device_to_device_streams = {stream_};
|
||||
}
|
||||
|
||||
if (!need_new_device_context) {
|
||||
@ -302,8 +319,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
|
||||
// ensures that the streams remain live for the duration of a run, even if
|
||||
// an error is encountered and the streams are replaced with new ones.
|
||||
device_context_ = new XlaDeviceContext(
|
||||
stream_, host_to_device_stream, device_to_host_stream, client(),
|
||||
transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
|
||||
stream_, std::move(host_to_device_stream),
|
||||
std::move(device_to_host_stream), std::move(device_to_device_streams),
|
||||
client(), shape_representation_fn_, thread_pool_.get());
|
||||
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
|
||||
<< device_context_;
|
||||
|
||||
|
@ -108,11 +108,6 @@ class XlaDevice : public LocalDevice {
|
||||
// The name of the compilation device (e.g., "XLA_CPU_JIT");
|
||||
string compilation_device_name;
|
||||
|
||||
// 'transfer_as_literal' is true if device<->host transfers must be done
|
||||
// using XLA's TransferLiteral{To,From}Device interface. If false, we can
|
||||
// use ThenMemcpy instead.
|
||||
bool transfer_as_literal = false;
|
||||
|
||||
// If 'use_multiple_streams' is true, we create separate streams for
|
||||
// compute, host-to-device, and device-to-host communication.
|
||||
bool use_multiple_streams = false;
|
||||
@ -188,6 +183,7 @@ class XlaDevice : public LocalDevice {
|
||||
se::Platform* const platform_; // Not owned.
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
|
||||
// Stream associated with this device. Operations enqueued on this
|
||||
// stream are executed on the device. Operations include data
|
||||
// copying back and forth between CPU and the device, and
|
||||
@ -203,9 +199,11 @@ class XlaDevice : public LocalDevice {
|
||||
// If use_multiple_streams_, device to host transfers are performed using this
|
||||
// stream.
|
||||
std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
|
||||
// Must we use XLA's transfer manager for correct host<->device transfers? if
|
||||
// false, we can use ThenMemcpy() instead.
|
||||
const bool transfer_as_literal_;
|
||||
// If use_multiple_streams_, transfers between different devices are performed
|
||||
// using these streams.
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
|
||||
GUARDED_BY(mu_);
|
||||
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// The device context accessed by all users of the XlaDevice, set by calls to
|
||||
|
@ -53,16 +53,17 @@ void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
|
||||
XlaDeviceContext::XlaDeviceContext(
|
||||
std::shared_ptr<se::Stream> compute_stream,
|
||||
std::shared_ptr<se::Stream> host_to_device_stream,
|
||||
std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
|
||||
bool transfer_as_literal,
|
||||
std::shared_ptr<se::Stream> device_to_host_stream,
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
|
||||
xla::LocalClient* client,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
thread::ThreadPool* thread_pool)
|
||||
: stream_(std::move(compute_stream)),
|
||||
host_to_device_stream_(std::move(host_to_device_stream)),
|
||||
device_to_host_stream_(std::move(device_to_host_stream)),
|
||||
device_to_device_streams_(std::move(device_to_device_streams)),
|
||||
client_(client),
|
||||
transfer_manager_(client->backend().transfer_manager()),
|
||||
transfer_as_literal_(transfer_as_literal),
|
||||
shape_representation_fn_(std::move(shape_representation_fn)),
|
||||
thread_pool_(thread_pool) {
|
||||
CHECK(host_to_device_stream_ != nullptr);
|
||||
@ -75,71 +76,6 @@ XlaDeviceContext::XlaDeviceContext(
|
||||
}
|
||||
}
|
||||
|
||||
Status XlaDeviceContext::TransferLiteralToDevice(const Tensor& host_tensor,
|
||||
Tensor* device_tensor) const {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
|
||||
host_tensor.shape(), &xla_shape));
|
||||
// Create a reference to hold onto host_tensor until after the literal has
|
||||
// been transferred. Also make sure the literal exists until the function
|
||||
// asynchronously completes, as it will be wrapped in an xla::LiteralSlice.
|
||||
TensorReference ref(host_tensor);
|
||||
auto literal = std::make_shared<xla::BorrowingLiteral>(
|
||||
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
|
||||
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
|
||||
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
|
||||
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
|
||||
<< shaped_buffer.ToString();
|
||||
if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow(
|
||||
stream_->parent(), shaped_buffer)) {
|
||||
// Initially wait for the compute stream so that memory allocations are
|
||||
// synchronized.
|
||||
host_to_device_stream_->ThenWaitFor(stream_.get());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
|
||||
host_to_device_stream_.get(), *literal, shaped_buffer));
|
||||
if (UseMultipleStreams()) {
|
||||
auto event = std::make_shared<se::Event>(stream_->parent());
|
||||
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
|
||||
host_to_device_stream_->ThenRecordEvent(event.get());
|
||||
xla_tensor->ResetDefinitionEvent(std::move(event),
|
||||
host_to_device_stream_.get());
|
||||
}
|
||||
// Unref the host tensor, and capture the literal shared_ptr too so it goes
|
||||
// out of scope when the lambda completes.
|
||||
// We don't defer the call to done() onto the stream here, and the reasons why
|
||||
// this is correct are subtle. We assume that:
|
||||
// a) all consumers of the device tensor will wait for its definition event.
|
||||
// b) if the tensor is destroyed, then the memory allocator will not hand out
|
||||
// the same buffers until the transfer has completed.
|
||||
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaDeviceContext::TransferLiteralFromDevice(
|
||||
Tensor* host_tensor, const Tensor& device_tensor,
|
||||
const StatusCallback& done) const {
|
||||
xla::MutableBorrowingLiteral literal;
|
||||
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
|
||||
|
||||
const xla::ShapedBuffer& shaped_buffer =
|
||||
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
|
||||
|
||||
TensorReference ref(device_tensor);
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
device_to_host_stream_.get(), shaped_buffer, literal,
|
||||
[=, &shaped_buffer](xla::Status status) {
|
||||
ref.Unref();
|
||||
done([&]() -> Status {
|
||||
VLOG(1) << "Transfer from device as literal: "
|
||||
<< shaped_buffer.ToString();
|
||||
return status;
|
||||
}());
|
||||
});
|
||||
}
|
||||
|
||||
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
Device* device,
|
||||
Tensor* device_tensor,
|
||||
@ -158,54 +94,73 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
<< cpu_tensor->shape().DebugString() << " "
|
||||
<< device_tensor->shape().DebugString();
|
||||
|
||||
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
|
||||
const int64 total_bytes = cpu_tensor->TotalBytes();
|
||||
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
|
||||
CHECK(xla_tensor);
|
||||
|
||||
xla::StatusOr<TensorShape> shape_or_status =
|
||||
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
|
||||
if (!shape_or_status.ok()) {
|
||||
done(shape_or_status.status());
|
||||
Status status = [&]() -> Status {
|
||||
TF_ASSIGN_OR_RETURN(TensorShape shape,
|
||||
shape_representation_fn_(device_tensor->shape(),
|
||||
device_tensor->dtype()));
|
||||
|
||||
// The device tensor should always be fresh.
|
||||
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
||||
|
||||
xla_tensor->set_host_tensor(*cpu_tensor);
|
||||
TF_RETURN_IF_ERROR(
|
||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||
stream_->parent()->device_ordinal()));
|
||||
|
||||
xla::BorrowingLiteral literal(
|
||||
static_cast<const char*>(DMAHelper::base(cpu_tensor)),
|
||||
xla_tensor->shaped_buffer().on_host_shape());
|
||||
|
||||
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
if (UseMultipleStreams() &&
|
||||
!transfer_manager_->CanShapedBufferBeAccessedNow(
|
||||
stream_->parent(), xla_tensor->shaped_buffer())) {
|
||||
// Initially wait for the compute stream so that memory allocations are
|
||||
// synchronized.
|
||||
host_to_device_stream_->ThenWaitFor(stream_.get());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
|
||||
host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));
|
||||
|
||||
if (UseMultipleStreams()) {
|
||||
auto event = std::make_shared<se::Event>(stream_->parent());
|
||||
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
|
||||
host_to_device_stream_->ThenRecordEvent(event.get());
|
||||
xla_tensor->ResetDefinitionEvent(std::move(event),
|
||||
host_to_device_stream_.get());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}();
|
||||
if (!status.ok()) {
|
||||
done(status);
|
||||
return;
|
||||
}
|
||||
TensorShape shape = shape_or_status.ValueOrDie();
|
||||
if (!xla_tensor->has_shaped_buffer()) {
|
||||
Status s =
|
||||
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
|
||||
stream_->parent()->device_ordinal());
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
Status status;
|
||||
if (transfer_as_literal_) {
|
||||
Tensor reshaped_cpu_tensor;
|
||||
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
|
||||
done(errors::Internal(
|
||||
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
|
||||
return;
|
||||
}
|
||||
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
|
||||
// Create a reference to hold onto cpu_tensor until after the literal has
|
||||
// been transferred
|
||||
TensorReference ref(*cpu_tensor);
|
||||
if (UseMultipleStreams()) {
|
||||
// Unref the host tensor when the transfer completes.
|
||||
// We don't defer the call to done() onto the stream here, and the reasons
|
||||
// why this is correct are subtle. We assume that:
|
||||
// a) all consumers of the device tensor will wait for its definition event.
|
||||
// b) if the tensor is destroyed, then the memory allocator will not hand
|
||||
// out the same buffers until the transfer has completed.
|
||||
host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
|
||||
done(status);
|
||||
} else {
|
||||
se::DeviceMemoryBase dev_dst_ptr =
|
||||
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
||||
host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
|
||||
// TODO(hpucha): Make this asynchronous.
|
||||
Status block_status = host_to_device_stream_->BlockHostUntilDone();
|
||||
if (!block_status.ok()) {
|
||||
status = xla::InternalError(
|
||||
"Failed to complete data transfer on stream %p: %s",
|
||||
host_to_device_stream_.get(), block_status.error_message().c_str());
|
||||
}
|
||||
host_to_device_stream_->ThenDoHostCallback([ref, done]() {
|
||||
ref.Unref();
|
||||
done(Status::OK());
|
||||
});
|
||||
}
|
||||
if (status.ok()) {
|
||||
xla_tensor->set_host_tensor(*cpu_tensor);
|
||||
}
|
||||
done(status);
|
||||
}
|
||||
|
||||
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
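The rewritten CopyCPUTensorToDevice above wraps the transfer steps in an immediately-invoked lambda that returns a Status, so every early exit funnels into a single done(status) call. A rough Python sketch of the same pattern, with hypothetical step names that are not part of the TensorFlow API:

def copy_with_single_status(steps, done):
    # `steps` are callables returning None on success or an error message on failure.
    # Like the immediately-invoked lambda, every early exit funnels into one done() call.
    def run_all():
        for step in steps:
            error = step()
            if error is not None:
                return error
        return None

    done(run_all())

copy_with_single_status([lambda: None, lambda: "allocation failed"],
                        done=lambda status: print("status:", status))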
@ -225,30 +180,31 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
<< cpu_tensor->shape().DebugString() << " "
|
||||
<< device_tensor->shape().DebugString();
|
||||
|
||||
const int64 total_bytes = cpu_tensor->TotalBytes();
|
||||
se::DeviceMemoryBase dev_src_ptr =
|
||||
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
||||
void* dst_ptr = DMAHelper::base(cpu_tensor);
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
|
||||
|
||||
xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get());
|
||||
|
||||
Status status;
|
||||
if (transfer_as_literal_) {
|
||||
TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
|
||||
return;
|
||||
} else {
|
||||
device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
|
||||
// TODO(hpucha): Make this asynchronous.
|
||||
Status block_status = device_to_host_stream_->BlockHostUntilDone();
|
||||
if (!block_status.ok()) {
|
||||
status = xla::InternalError(
|
||||
"Failed to complete data transfer on stream %p: %s", stream_.get(),
|
||||
block_status.error_message().c_str());
|
||||
}
|
||||
}
|
||||
xla::MutableBorrowingLiteral literal;
|
||||
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal));
|
||||
|
||||
done(status);
|
||||
TensorReference ref(*device_tensor);
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal,
|
||||
[ref, xla_tensor, done](xla::Status status) {
|
||||
done([&]() -> Status {
|
||||
VLOG(1) << "Transfer from device as literal: "
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
return status;
|
||||
}());
|
||||
ref.Unref();
|
||||
});
|
||||
}
|
||||
|
||||
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
|
||||
DCHECK_GT(device_to_device_streams_.size(), 0);
|
||||
absl::MutexLock lock(&mu_);
|
||||
int stream = next_stream_;
|
||||
next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
|
||||
return device_to_device_stream(stream);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
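GetDeviceToDeviceStream() above hands out device-to-device streams round-robin under a mutex. A minimal Python sketch of that selection policy (the class and names here are illustrative only, not TensorFlow API):

import threading

class StreamPool:
    """Round-robin pool, mirroring XlaDeviceContext::GetDeviceToDeviceStream()."""

    def __init__(self, streams):
        assert streams, "at least one device-to-device stream is required"
        self._streams = list(streams)
        self._next = 0
        self._lock = threading.Lock()

    def next_stream(self):
        # Pick the current stream and advance the cursor under the lock,
        # wrapping around once every stream has been handed out.
        with self._lock:
            index = self._next
            self._next = (self._next + 1) % len(self._streams)
        return self._streams[index]

pool = StreamPool(["d2d_stream_0", "d2d_stream_1"])
print([pool.next_stream() for _ in range(3)])  # ['d2d_stream_0', 'd2d_stream_1', 'd2d_stream_0']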
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
@ -50,7 +51,8 @@ class XlaDeviceContext : public DeviceContext {
|
||||
std::shared_ptr<se::Stream> compute_stream,
|
||||
std::shared_ptr<se::Stream> host_to_device_stream,
|
||||
std::shared_ptr<se::Stream> device_to_host_stream,
|
||||
xla::LocalClient* client, bool transfer_as_literal,
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
|
||||
xla::LocalClient* client,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
thread::ThreadPool* thread_pool);
|
||||
|
||||
@ -61,14 +63,26 @@ class XlaDeviceContext : public DeviceContext {
|
||||
absl::string_view tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
|
||||
xla::LocalClient* client() const { return client_; }
|
||||
se::Stream* stream() const { return stream_.get(); }
|
||||
se::Stream* host_to_device_stream() const {
|
||||
return host_to_device_stream_.get();
|
||||
}
|
||||
se::Stream* device_to_host_stream() const {
|
||||
return device_to_host_stream_.get();
|
||||
}
|
||||
se::Stream* device_to_device_stream(int index) const {
|
||||
return device_to_device_streams_.at(index).get();
|
||||
}
|
||||
xla::TransferManager* transfer_manager() const { return transfer_manager_; }
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
|
||||
return shape_representation_fn_;
|
||||
}
|
||||
|
||||
// Returns a device-to-device stream, in round-robin fashion.
|
||||
se::Stream* GetDeviceToDeviceStream();
|
||||
|
||||
private:
|
||||
Status TransferLiteralToDevice(const Tensor& host_tensor,
|
||||
Tensor* device_tensor) const;
|
||||
void TransferLiteralFromDevice(Tensor* host_tensor,
|
||||
const Tensor& device_tensor,
|
||||
const StatusCallback& done) const;
|
||||
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
|
||||
|
||||
// The main compute stream of the device, used to synchronize the transfer
|
||||
@ -80,16 +94,22 @@ class XlaDeviceContext : public DeviceContext {
|
||||
// The stream to use for transferring data from device to host. Can be
|
||||
// identical to stream_, but must not be nullptr.
|
||||
std::shared_ptr<se::Stream> device_to_host_stream_;
|
||||
// Streams to use for transferring data directly between different devices,
|
||||
// e.g., over NVLINK.
|
||||
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
|
||||
|
||||
// For the underlying memory allocator and XLA's TransferManager.
|
||||
xla::LocalClient* client_;
|
||||
// Transfer manager, for marshalling data to and from the device.
|
||||
xla::TransferManager* transfer_manager_;
|
||||
// True if we must use XLA's TransferManager for correct device transfers.
|
||||
const bool transfer_as_literal_;
|
||||
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// Thread pool used for running closures
|
||||
thread::ThreadPool* thread_pool_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
int next_stream_ GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/resource_variable_ops.h"
|
||||
#include "tensorflow/core/kernels/sendrecv_ops.h"
|
||||
#include "tensorflow/core/kernels/shape_ops.h"
|
||||
#include "tensorflow/core/kernels/stack.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<string>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp);
|
||||
RetvalOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name("StackV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("max_size") \
|
||||
.HostMemory("handle"), \
|
||||
StackOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("T", TYPES), \
|
||||
TemplatedStackPushOp</*allow_swapping=*/false>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("handle") \
|
||||
.TypeConstraint("elem_type", TYPES), \
|
||||
StackPopOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
|
||||
|
||||
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
|
||||
// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
|
||||
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
|
||||
// and write the tensors they access in order to concatenate them into a batch.
|
||||
// We would need either to call out to an XLA computation to perform the
|
||||
|
@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
std::vector<Device*>* devices) {
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.requires_compilation = true;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
@ -59,7 +59,6 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
|
||||
options.device_name = DEVICE_XLA_GPU;
|
||||
options.device_ordinal = 0;
|
||||
options.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
options.transfer_as_literal = false;
|
||||
options.use_multiple_streams = false;
|
||||
auto device = absl::make_unique<XlaDevice>(session_options, options);
|
||||
|
||||
|
@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
|
||||
XlaOpRegistry::DeviceRegistration registration;
|
||||
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
|
||||
registration.requires_compilation = true;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_resource_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
@ -60,7 +60,6 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
options.device_name = DEVICE_XLA_INTERPRETER;
|
||||
options.device_ordinal = 0;
|
||||
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
|
||||
options.transfer_as_literal = false;
|
||||
options.use_multiple_streams = false;
|
||||
auto device = absl::make_unique<XlaDevice>(session_options, options);
|
||||
devices->push_back(device.release());
|
||||
|
@ -375,6 +375,27 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "resampler_ops_test",
|
||||
size = "small",
|
||||
srcs = ["resampler_ops_test.py"],
|
||||
disabled_backends = [
|
||||
# TODO(b/74459949) Support BatchDot in CPU backend.
|
||||
"cpu",
|
||||
"cpu_ondemand",
|
||||
],
|
||||
# TODO(b/112295522): figure out how to make OSS build pass.
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/contrib/resampler:resampler_ops",
|
||||
"//tensorflow/contrib/resampler:resampler_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "dynamic_stitch_test",
|
||||
size = "small",
|
||||
@ -489,8 +510,6 @@ tf_xla_py_test(
|
||||
name = "function_test",
|
||||
size = "small",
|
||||
srcs = ["function_test.py"],
|
||||
# Functions are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = "cpu_ondemand",
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -680,9 +699,6 @@ tf_xla_py_test(
|
||||
name = "random_ops_test",
|
||||
size = "small",
|
||||
srcs = ["random_ops_test.py"],
|
||||
disabled_backends = [
|
||||
"cpu_ondemand",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -698,6 +714,10 @@ tf_xla_py_test(
|
||||
size = "medium",
|
||||
srcs = ["reduce_ops_test.py"],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
# TODO(b/119059212): Re-enable this test in OSS.
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -713,7 +733,6 @@ tf_xla_py_test(
|
||||
name = "reduce_window_test",
|
||||
size = "small",
|
||||
srcs = ["reduce_window_test.py"],
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
@ -822,8 +841,6 @@ tf_xla_py_test(
|
||||
name = "stack_ops_test",
|
||||
size = "small",
|
||||
srcs = ["stack_ops_test.py"],
|
||||
# Stack ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = "cpu_ondemand",
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -851,7 +868,7 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
srcs = ["tensor_array_ops_test.py"],
|
||||
# TensorArray ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = "cpu_ondemand",
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -872,7 +889,7 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
srcs = ["tensor_list_ops_test.py"],
|
||||
# TensorList ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = "cpu_ondemand",
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -952,7 +969,6 @@ tf_xla_py_test(
|
||||
name = "while_test",
|
||||
size = "small",
|
||||
srcs = ["while_test.py"],
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
@ -1109,6 +1125,7 @@ cc_library(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
@ -1219,7 +1236,6 @@ tf_xla_py_test(
|
||||
name = "xla_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["xla_ops_test.py"],
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
|
@ -178,6 +178,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
[0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
|
||||
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops.leaky_relu_grad,
|
||||
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
|
||||
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
|
||||
expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10],
|
||||
dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops.softmax_cross_entropy_with_logits,
|
||||
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),
|
||||
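The leaky_relu_grad expectations added in the hunk above are consistent with an alpha of 0.2 (an assumption here, since the test relies on the op's default rather than passing alpha explicitly): the gradient passes through where features > 0 and is scaled by alpha elsewhere. A small NumPy check:

import numpy as np

def leaky_relu_grad(gradients, features, alpha=0.2):
    # Pass gradients through where the forward input was positive, scale by alpha otherwise.
    return np.where(features > 0, gradients, alpha * gradients)

g = np.arange(1, 11, dtype=np.float32)
f = np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=np.float32)
print(leaky_relu_grad(g, f))
# [ 0.2  0.4  0.6  0.8  1.   6.   7.   8.   9.  10. ]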
|
@ -50,6 +50,8 @@ def tf_xla_py_test(
|
||||
"""
|
||||
if disabled_backends == None:
|
||||
disabled_backends = []
|
||||
if type(disabled_backends) != "list":
|
||||
fail("disabled_backends must be a list of strings", "disabled_backends")
|
||||
|
||||
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
|
||||
test_names = []
|
||||
|
@ -135,7 +135,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-2.60260963, -4.29698515]),
|
||||
var0.eval(),
|
||||
float_rtol=1e-5,
|
||||
float_rtol=1e-4,
|
||||
half_rtol=1e-2)
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-0.28432083, -0.56694895]),
|
||||
@ -167,7 +167,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5)
|
||||
np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5,
|
||||
float_rtol=1e-4)
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5)
|
||||
|
||||
|
@ -448,8 +448,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
for dtype in self.float_types:
|
||||
self._assertForwardOpMatchesExpected(
|
||||
np.array([[1, 2]], dtype=dtype), [3, 3],
|
||||
expected=np.array(
|
||||
[[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32))
|
||||
expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners1x2To3x2Grad(self):
|
||||
for dtype in self.float_types:
|
||||
@ -477,8 +477,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
for dtype in self.float_types:
|
||||
self._assertForwardOpMatchesExpected(
|
||||
np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3],
|
||||
expected=np.array(
|
||||
[[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32))
|
||||
expected=np.array([[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners2x2To3x3Grad(self):
|
||||
self._assertBackwardOpMatchesExpected(
|
||||
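The expected values in the align_corners tests above follow directly from bilinear interpolation with the align_corners scale (in_size - 1) / (out_size - 1). A small NumPy check for the 2x2 -> 3x3 case (the helper below is illustrative only, not the TensorFlow kernel):

import numpy as np

def resize_bilinear_align_corners(image, out_h, out_w):
    in_h, in_w = image.shape
    ys = np.linspace(0, in_h - 1, out_h)  # scale = (in - 1) / (out - 1)
    xs = np.linspace(0, in_w - 1, out_w)
    out = np.empty((out_h, out_w), dtype=np.float32)
    for i, y in enumerate(ys):
        y0, dy = int(np.floor(y)), y - np.floor(y)
        y1 = min(y0 + 1, in_h - 1)
        for j, x in enumerate(xs):
            x0, dx = int(np.floor(x)), x - np.floor(x)
            x1 = min(x0 + 1, in_w - 1)
            out[i, j] = (image[y0, x0] * (1 - dx) * (1 - dy) +
                         image[y0, x1] * dx * (1 - dy) +
                         image[y1, x0] * (1 - dx) * dy +
                         image[y1, x1] * dx * dy)
    return out

print(resize_bilinear_align_corners(np.array([[1., 2.], [3., 4.]]), 3, 3))
# -> [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], matching testAlignCorners2x2To3x3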
@ -498,8 +498,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
np.array([[7, 13], [22, 4]], dtype=np.float32),
|
||||
input_shape=[3, 3],
|
||||
dtype=dtype,
|
||||
expected=np.array(
|
||||
[[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32))
|
||||
expected=np.array([[7, 0, 13], [0, 0, 0], [22, 0, 4]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners4x4To3x3(self):
|
||||
for dtype in self.float_types:
|
||||
@ -507,8 +507,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
np.array(
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
|
||||
dtype=dtype), [3, 3],
|
||||
expected=np.array(
|
||||
[[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32))
|
||||
expected=np.array([[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners4x4To3x3Grad(self):
|
||||
for dtype in self.float_types:
|
||||
@ -516,41 +516,39 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32),
|
||||
input_shape=[4, 4],
|
||||
dtype=dtype,
|
||||
expected=np.array(
|
||||
[[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3],
|
||||
[7, 4, 4, 9]],
|
||||
dtype=np.float32))
|
||||
expected=np.array([[1, 1, 1, 3], [2, 1.25, 1.25, 3],
|
||||
[2, 1.25, 1.25, 3], [7, 4, 4, 9]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners3x3To9x9(self):
|
||||
for dtype in self.float_types:
|
||||
self._assertForwardOpMatchesExpected(
|
||||
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9],
|
||||
expected=np.array(
|
||||
[[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [
|
||||
1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75
|
||||
], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [
|
||||
3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25
|
||||
], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [
|
||||
4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75
|
||||
], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [
|
||||
6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25
|
||||
], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
|
||||
[[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00],
|
||||
[1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75],
|
||||
[2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50],
|
||||
[3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25],
|
||||
[4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00],
|
||||
[4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75],
|
||||
[5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50],
|
||||
[6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25],
|
||||
[7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
|
||||
dtype=np.float32))
|
||||
|
||||
def testAlignCorners3x3To9x9Grad(self):
|
||||
for dtype in self.float_types:
|
||||
self._assertBackwardOpMatchesExpected(
|
||||
np.array(
|
||||
[[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [
|
||||
1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75
|
||||
], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [
|
||||
3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25
|
||||
], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [
|
||||
4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75
|
||||
], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [
|
||||
6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25
|
||||
], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
|
||||
dtype=np.float32),
|
||||
np.array([[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00],
|
||||
[1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75],
|
||||
[2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50],
|
||||
[3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25],
|
||||
[4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00],
|
||||
[4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75],
|
||||
[5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50],
|
||||
[6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25],
|
||||
[7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
|
||||
dtype=np.float32),
|
||||
input_shape=[3, 3],
|
||||
dtype=dtype,
|
||||
expected=np.array(
|
||||
@ -571,12 +569,12 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
(np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
|
||||
[[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0,
|
||||
[16, 16],
|
||||
expected=7 * (np.array(
|
||||
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
|
||||
dtype=np.float32) + np.array(
|
||||
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
|
||||
[12], [13], [14], [15]],
|
||||
dtype=np.float32)),
|
||||
expected=7 *
|
||||
(np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
|
||||
dtype=np.float32) +
|
||||
np.array([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
|
||||
[12], [13], [14], [15]],
|
||||
dtype=np.float32)),
|
||||
large_tolerance=True)
|
||||
|
||||
def testNonAlignCorners3x2To6x4(self):
|
||||
@ -600,6 +598,26 @@ class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
expected=np.array(expected_data, dtype=dtype),
|
||||
align_corners=False)
|
||||
|
||||
def testNonAlignCorners3x2To6x4Batch2(self):
|
||||
input_data = [[[64, 32], [32, 64], [50, 100]], [[32, 16], [16, 32],
|
||||
[25, 50]]]
|
||||
expected_data = [[[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0],
|
||||
[32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0],
|
||||
[50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]],
|
||||
[[32.0, 24.0, 16.0, 16.0], [24.0, 24.0, 24.0, 24.0],
|
||||
[16.0, 24.0, 32.0, 32.0], [20.5, 30.75, 41.0, 41.0],
|
||||
[25.0, 37.5, 50.0, 50.0], [25.0, 37.5, 50.0, 50.0]]]
|
||||
|
||||
for dtype in self.float_types:
|
||||
input_image = np.array(input_data, dtype=dtype)
|
||||
expected = np.array(expected_data, dtype=dtype)
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
image = array_ops.placeholder(input_image.dtype)
|
||||
resized = gen_image_ops.resize_bilinear(
|
||||
image, [6, 4], align_corners=False)
|
||||
out = sess.run(resized, {image: input_image[:, :, :, np.newaxis]})
|
||||
self.assertAllClose(expected[:, :, :, np.newaxis], out)
|
||||
|
||||
|
||||
class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||
|
||||
@ -804,5 +822,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||
self.assertEqual(num_valid, 3)
|
||||
self.assertAllClose(indices_tf[:num_valid], [0, 2, 4])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
@ -2687,6 +2688,37 @@ TEST_F(OpTest, Reverse) {
|
||||
});
|
||||
}
|
||||
|
||||
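// For reference (illustrative example, not part of the test below): ReverseSequence
// with batch_dim=0, seq_dim=1, input [[1, 2, 3], [4, 5, 6]] and seq_lens [2, 3]
// reverses only the first seq_lens[i] elements of each batch row,
// giving [[2, 1, 3], [6, 5, 4]].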
TEST_F(OpTest, ReverseSequence) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> dims = RandomDims(/*min_rank=*/2);
|
||||
auto type = Choose<DataType>(kAllXlaTypes);
|
||||
int64 rank = dims.size();
|
||||
|
||||
// Choose random batch and sequence dimensions.
|
||||
std::vector<int> shuffled_dim_ids(rank);
|
||||
absl::c_iota(shuffled_dim_ids, 0);
|
||||
absl::c_shuffle(shuffled_dim_ids, generator());
|
||||
shuffled_dim_ids.resize(2);
|
||||
int batch_dim = shuffled_dim_ids[0];
|
||||
int seq_dim = shuffled_dim_ids[1];
|
||||
|
||||
int batch_size = dims[batch_dim];
|
||||
int max_seq_len = dims[seq_dim];
|
||||
std::vector<int32> seq_lens(batch_size);
|
||||
std::uniform_int_distribution<int32> d(0, max_seq_len);
|
||||
absl::c_generate(seq_lens, [&]() { return d(generator()); });
|
||||
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("ReverseSequence")
|
||||
.RandomInput(type, dims)
|
||||
.Input(test::AsTensor<int32>(seq_lens))
|
||||
.Attr("seq_dim", seq_dim)
|
||||
.Attr("batch_dim", batch_dim)
|
||||
.Attr("T", type)
|
||||
.Attr("Tlen", DT_INT32));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, ReverseV2) {
|
||||
Repeatedly([this]() {
|
||||
auto type = Choose<DataType>(kAllXlaTypes);
|
||||
|
156
tensorflow/compiler/tests/resampler_ops_test.py
Normal file
@ -0,0 +1,156 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for resampler ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.contrib import resampler
|
||||
from tensorflow.contrib.resampler.ops import gen_resampler_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ResamplerOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
input_image = array_ops.placeholder(image_np.dtype)
|
||||
warp = array_ops.placeholder(warp_np.dtype)
|
||||
resampled = resampler.resampler(input_image, warp, name='resampler')
|
||||
out = sess.run(resampled, {input_image: image_np, warp: warp_np})
|
||||
|
||||
self.assertAllCloseAccordingToType(
|
||||
expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2)
|
||||
|
||||
def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np,
|
||||
expected_grad_data, expected_grad_warp):
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
input_image = array_ops.placeholder(input_np.dtype)
|
||||
warp = array_ops.placeholder(warp_np.dtype)
|
||||
grad_output = array_ops.placeholder(grad_output_np.dtype)
|
||||
|
||||
grad_data, grad_warp = gen_resampler_ops.resampler_grad(
|
||||
input_image, warp, grad_output)
|
||||
|
||||
grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], {
|
||||
input_image: input_np,
|
||||
warp: warp_np,
|
||||
grad_output: grad_output_np
|
||||
})
|
||||
|
||||
self.assertAllCloseAccordingToType(
|
||||
expected_grad_warp, grad_warp_tf, half_rtol=1e-2, bfloat16_rtol=3e-2)
|
||||
self.assertAllCloseAccordingToType(
|
||||
expected_grad_data, grad_data_tf, half_rtol=1e-2, bfloat16_rtol=3e-2)
|
||||
|
||||
def testSimple(self):
|
||||
for dtype in self.float_types:
|
||||
input_shape = [1, 2, 2, 1]
|
||||
input_rgb_data = [0, 5, 13, 54]
|
||||
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
|
||||
|
||||
warp_shape = [1, 2]
|
||||
warp_data = [0.7, 0.6]
|
||||
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
|
||||
expected = [[26.42]]
|
||||
self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
|
||||
|
||||
grad_output = np.ones([1, 1], dtype=dtype)
|
||||
|
||||
expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001],
|
||||
[0.42000002]]]]
|
||||
|
||||
expected_grad_warp = [[26.60000038, 38.20000076]]
|
||||
|
||||
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
|
||||
expected_grad_data,
|
||||
expected_grad_warp)
|
||||
|
||||
def testMultiChannel(self):
|
||||
for dtype in self.float_types:
|
||||
input_shape = [1, 2, 2, 3]
|
||||
input_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
|
||||
|
||||
warp_shape = [1, 2]
|
||||
warp_data = [0.7, 0.6]
|
||||
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
|
||||
expected = [[59.58000183, 146.94000244, 107.37999725]]
|
||||
self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
|
||||
|
||||
grad_output = np.ones([1, 3], dtype=dtype)
|
||||
|
||||
expected_grad_data = [[[[0.12, 0.12, 0.12],
|
||||
[0.27999997, 0.27999997, 0.27999997]],
|
||||
[[0.18000001, 0.18000001, 0.18000001],
|
||||
[0.42000002, 0.42000002, 0.42000002]]]]
|
||||
|
||||
expected_grad_warp = [[199, 30]]
|
||||
|
||||
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
|
||||
expected_grad_data,
|
||||
expected_grad_warp)
|
||||
|
||||
def testBatch2Height3byWidth3RGB(self):
|
||||
for dtype in self.float_types:
|
||||
input_shape = [2, 3, 3, 3]
|
||||
input_rgb_data = [
|
||||
0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1, 30, 105, 2, 40, 115,
|
||||
3, 50, 125, 4, 60, 135, 5, 70, 145, 6, 0, 5, 13, 54, 135, 226, 37, 8,
|
||||
234, 90, 255, 1, 30, 105, 2, 40, 115, 3, 50, 125, 4, 60, 135, 5, 70,
|
||||
145, 6
|
||||
]
|
||||
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
|
||||
|
||||
# 2 batches and 2 samples for each batch.
|
||||
warp_shape = [2, 2, 2]
|
||||
warp_data = [0.7, 0.6, 1, 0.7, 0.9, 1.2, 1.3, 1.6]
|
||||
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
|
||||
|
||||
expected_forward = [[[43.92, 128.4, 65.86], [37.2, 114., 69.2]],
|
||||
[[40.6, 122.8, 2.5], [51., 126, 4.1]]]
|
||||
|
||||
self._assertForwardOpMatchesExpected(input_np, warp_np, expected_forward)
|
||||
|
||||
expected_grad_data = [[[[0.12, 0.12, 0.12],
|
||||
[0.57999998, 0.57999998, 0.57999998],
|
||||
[0., 0., 0.]],
|
||||
[[0.18000001, 0.18000001, 0.18000001],
|
||||
[1.12, 1.12, 1.12], [0., 0., 0.]],
|
||||
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
|
||||
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
|
||||
[[0.08000001, 0.08000001, 0.08000001],
|
||||
[0.99999988, 0.99999988, 0.99999988],
|
||||
[0.11999997, 0.11999997, 0.11999997]],
|
||||
[[0.02000001, 0.02000001, 0.02000001],
|
||||
[0.60000008, 0.60000008, 0.60000008],
|
||||
[0.17999998, 0.17999998, 0.17999998]]]]
|
||||
expected_grad_warp = [[[33.39999008, -96.20000458], [-26.10000229,
|
||||
-278.]],
|
||||
[[-162.99998474, 39.99999619], [21., 63.]]]
|
||||
|
||||
grad_output = np.ones([2, 2, 3], dtype=dtype)
|
||||
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
|
||||
expected_grad_data,
|
||||
expected_grad_warp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
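The expected values in testSimple above match plain bilinear sampling of the 2x2 image at the warp point (x, y) = (0.7, 0.6). A NumPy sketch that reproduces the forward value and the warp gradient (illustrative only, not the resampler implementation):

import numpy as np

image = np.array([[0., 5.], [13., 54.]])  # indexed as image[y, x]
px, py = 0.7, 0.6                          # warp point, relative to the top-left pixel

# Forward: bilinear interpolation over the 2x2 neighborhood.
value = ((1 - px) * (1 - py) * image[0, 0] + px * (1 - py) * image[0, 1] +
         (1 - px) * py * image[1, 0] + px * py * image[1, 1])
print(value)  # 26.42, as expected in testSimple

# Backprop into warp, per the CalculateGradWarp derivation in resampler_ops.cc:
grad_warp_x = py * (image[1, 1] - image[1, 0]) + (1 - py) * (image[0, 1] - image[0, 0])
grad_warp_y = px * (image[1, 1] - image[0, 1]) + (1 - px) * (image[1, 0] - image[0, 0])
print(grad_warp_x, grad_warp_y)  # 26.6 38.2, matching expected_grad_warp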
@ -358,6 +358,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[-0.05, 6.05, 5]], dtype=dtype),
|
||||
expected=np.array([[0, 6, 5]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.leaky_relu,
|
||||
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
|
||||
expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softmax,
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
|
@ -194,6 +194,7 @@ cc_library(
|
||||
":side_effect_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -178,6 +178,32 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
# A separate cc_library for resampler_ops is needed because resampler is in
|
||||
# contrib/, and thus the declaration of resampler cannot be pulled into the deps
|
||||
# of xla_ops. Therefore, resampler_ops is its own cc_library target, and its
|
||||
# corresponding tf_kernel_library is defined in contrib/resampler/BUILD.
|
||||
cc_library(
|
||||
name = "resampler_ops",
|
||||
srcs = ["resampler_ops.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
"//tensorflow/compiler/xla/client/lib:numeric",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conv_op_helpers",
|
||||
srcs = ["conv_op_helpers.cc"],
|
||||
|
@ -231,20 +231,22 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
|
||||
num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
|
||||
num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
|
||||
|
||||
const int64 batch_dim_size =
|
||||
builder->GetShape(input).ValueOrDie().dimensions(0);
|
||||
if (num_extended[0] > 0) {
|
||||
auto slice =
|
||||
xla::Slice(input_data, {0, in_size[0] - 1, 0, 0},
|
||||
{1, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
|
||||
auto slice = xla::Slice(
|
||||
input_data, {0, in_size[0] - 1, 0, 0},
|
||||
{batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
|
||||
for (int i = 0; i < num_extended[0]; i++) {
|
||||
input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (num_extended[1] > 0) {
|
||||
auto slice =
|
||||
xla::Slice(input_data, {0, 0, in_size[1] - 1, 0},
|
||||
{1, in_size[0] + num_extended[0], in_size[1], channels},
|
||||
{1, 1, 1, 1});
|
||||
auto slice = xla::Slice(
|
||||
input_data, {0, 0, in_size[1] - 1, 0},
|
||||
{batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
|
||||
{1, 1, 1, 1});
|
||||
for (int i = 0; i < num_extended[1]; i++) {
|
||||
input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
|
||||
}
|
||||
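The change above sizes the edge-extension slice by the actual batch dimension instead of a hard-coded 1, so every image in the batch gets its last row (or column) replicated. A NumPy sketch of that edge extension under the same NHWC layout (the helper name is illustrative only):

import numpy as np

def extend_bottom_edge(batch_nhwc, num_extended):
    # Replicate the last row for every image in the batch, as the XLA slice/concat now does.
    last_row = batch_nhwc[:, -1:, :, :]            # [batch, 1, width, channels]
    extra = np.repeat(last_row, num_extended, axis=1)
    return np.concatenate([batch_nhwc, extra], axis=1)

x = np.arange(2 * 2 * 2 * 1).reshape(2, 2, 2, 1)   # batch of 2 images
print(extend_bottom_edge(x, 1).shape)              # (2, 3, 2, 1)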
|
@ -15,14 +15,12 @@ limitations under the License.
|
||||
|
||||
// Native XLA implementations of XLA Relu Ops
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/no_op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -37,6 +35,7 @@ class ReluOp : public XlaOpKernel {
|
||||
ctx->SetOutput(0, xla::Max(zero, ctx->Input(0)));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("Relu"), ReluOp);
|
||||
|
||||
class Relu6Op : public XlaOpKernel {
|
||||
public:
|
||||
@ -49,6 +48,22 @@ class Relu6Op : public XlaOpKernel {
|
||||
ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("Relu6"), Relu6Op);
|
||||
|
||||
class LeakyReluOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto features = ctx->Input("features");
|
||||
auto output =
|
||||
xla::Max(features, features * xla::ScalarLike(features, alpha_));
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
float alpha_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp);
|
||||
|
||||
class ReluGradOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -64,6 +79,7 @@ class ReluGradOp : public XlaOpKernel {
|
||||
ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp);
|
||||
|
||||
class Relu6GradOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -83,11 +99,24 @@ class Relu6GradOp : public XlaOpKernel {
|
||||
ctx->SetOutput(0, out);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("Relu"), ReluOp);
|
||||
REGISTER_XLA_OP(Name("Relu6"), Relu6Op);
|
||||
REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp);
|
||||
REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp);
|
||||
|
||||
class LeakyReluGradOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto gradients = ctx->Input("gradients");
|
||||
auto features = ctx->Input("features");
|
||||
auto output =
|
||||
xla::Select(xla::Gt(features, xla::ScalarLike(features, 0)), gradients,
|
||||
gradients * xla::ScalarLike(gradients, alpha_));
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
float alpha_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
541
tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
Normal file
@ -0,0 +1,541 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/array4d.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/numeric.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using xla::XlaOp;
|
||||
|
||||
// TODO(b/112295522): note that sampling from image boundary is not currently
|
||||
// being handled properly.
|
||||
|
||||
// Calculates the bilinear weight tensor, given basis ratio (px, py) of the
|
||||
// sampling position:
|
||||
// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
|
||||
// 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2].
|
||||
//
|
||||
// The returned tensor has dimensions [batch, dim_0, ... dim_n, 4].
|
||||
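// For example (illustrative numbers only): for a sampling ratio of
// (px, py) = (0.7, 0.6), the weights for the four neighbors are
// [0.3 * 0.4, 0.7 * 0.4, 0.3 * 0.6, 0.7 * 0.6] = [0.12, 0.28, 0.18, 0.42].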
XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio,
|
||||
const TensorShape warp_shape,
|
||||
xla::PrimitiveType xla_type) {
|
||||
auto first_term = xla::ConstantR2<float>(
|
||||
ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}});
|
||||
first_term = xla::ConvertElementType(first_term, xla_type);
|
||||
|
||||
auto warp_dims = warp_shape.dim_sizes();
|
||||
std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1);
|
||||
broadcast_dims.push_back(4);
|
||||
broadcast_dims.push_back(2);
|
||||
|
||||
const int64 broadcast_dims_size = broadcast_dims.size();
|
||||
|
||||
std::vector<int64> last_two_dims_indices = {(broadcast_dims_size - 2),
|
||||
(broadcast_dims_size - 1)};
|
||||
|
||||
xla::Shape broadcast_shape =
|
||||
xla::ShapeUtil::MakeShape(xla_type, broadcast_dims);
|
||||
|
||||
auto broadcast_first_term =
|
||||
xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices);
|
||||
|
||||
// Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n,
|
||||
// 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the
|
||||
// [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last
|
||||
// dimension.
|
||||
std::vector<int64> ratio_broadcast_indices(broadcast_dims.size());
|
||||
std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0);
|
||||
ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2);
|
||||
|
||||
auto broadcast_ratio =
|
||||
xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices);
|
||||
|
||||
auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio;
|
||||
|
||||
// Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to
|
||||
// flip the signs of the second and the third term.
|
||||
auto sign_change = xla::ConstantR2<float>(
|
||||
ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}});
|
||||
sign_change = xla::ConvertElementType(sign_change, xla_type);
|
||||
|
||||
auto broadcast_sign_change =
|
||||
xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices);
|
||||
|
||||
auto flipped = first_term_subtract_weights * broadcast_sign_change;
|
||||
|
||||
// Build up the final bilinear weight tensor by multiply reduction, which
|
||||
// gives:
|
||||
// [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
|
||||
// for each 4 neighboring pixels where px and py are the weight of the target
|
||||
// pixel we are sampling from.
|
||||
return xla::Reduce(
|
||||
flipped, xla::One(ctx->builder(), xla_type),
|
||||
xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()),
|
||||
{broadcast_dims_size - 1});
|
||||
}
|
||||
|
||||
// Concatenates the batch indices to the (x, y) coordinate indices.
|
||||
// This is done by first creating an Iota tensor that represents the current
|
||||
// batch it is in, then concatenating it with the given (coordinate) indices.
|
||||
//
|
||||
// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where
|
||||
// the last dimension of size 3 in turn is [batch_number, x, y].
|
||||
// The [batch_number, x, y] dimension is needed because the indices
|
||||
// [x,y] alone cannot allow the xla::Gather operation to gather from the input
|
||||
// data, which is of dimension [batch, height(y), width(x), channel] with
|
||||
// 'batch' being the first dimension.
|
||||
XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
|
||||
const TensorShape& warp_shape) {
|
||||
// We need to create an iota tensor with the same batch dimension.
|
||||
std::vector<int64> dimensions;
|
||||
for (auto dim : warp_shape) {
|
||||
dimensions.push_back(dim.size);
|
||||
}
|
||||
// Except the last dimension, which is of size 1.
|
||||
dimensions.back() = 1;
|
||||
|
||||
auto batch_indices =
|
||||
xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions),
|
||||
/*iota_dimension=*/0);
|
||||
|
||||
return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
|
||||
}
|
||||
|
||||
// Gathers the 2x2 neighbors of the input starting_indices, and return a
|
||||
// tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels].
|
||||
// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last
|
||||
// dimension of size 3 is (batch_no, x, y).
|
||||
XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices,
|
||||
int64 data_channels, int warp_dims) {
|
||||
xla::GatherDimensionNumbers gather_dim_numbers;
|
||||
const int64 neighbor_data_dimensions = warp_dims + 2;
|
||||
// Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2,
|
||||
// data_channels], the offset dimensions for Gather is the last 3 dimensions.
|
||||
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3);
|
||||
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2);
|
||||
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1);
|
||||
// The last dimension of 'gather_indices' is the starting indices for gather.
|
||||
gather_dim_numbers.set_index_vector_dim(warp_dims - 1);
|
||||
gather_dim_numbers.add_collapsed_slice_dims(0);
|
||||
gather_dim_numbers.add_start_index_map(0);
|
||||
// Since input is of dimension [batch, height(y), width(x), channel], and warp
|
||||
// is of dimension [batch, x, y], the ordering of x, y here needs to be
|
||||
// swapped when gathering.
|
||||
gather_dim_numbers.add_start_index_map(2);
|
||||
gather_dim_numbers.add_start_index_map(1);
|
||||
// Data dimensions are [batch, x, y, channel].
|
||||
// Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels].
|
||||
auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers,
|
||||
/*slice_sizes=*/{1, 2, 2, data_channels});
|
||||
// Collapse the ...,2,2,... dimensions into ...,4,...
|
||||
return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims});
|
||||
}
|
||||
|
||||
// Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the
|
||||
// resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels].
|
||||
// This function can also be seen as the inverse of 'Gather2by2Neighbors'.
|
||||
XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
|
||||
XlaOp updates, int64 warp_dims,
|
||||
xla::PrimitiveType xla_type) {
|
||||
xla::ScatterDimensionNumbers scatter_dim_numbers;
|
||||
const int64 neighbor_data_dimensions = warp_dims + 2;
|
||||
// Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2,
|
||||
// data_channels], the update window dimensions is the last 3 dimensions.
|
||||
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3);
|
||||
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2);
|
||||
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1);
|
||||
scatter_dim_numbers.set_index_vector_dim(warp_dims - 1);
|
||||
|
||||
scatter_dim_numbers.add_inserted_window_dims(0);
|
||||
scatter_dim_numbers.add_scatter_dims_to_operand_dims(0);
|
||||
// Since input is of dimension [batch, height(y), width(x), channel], and warp
|
||||
// is of dimension [batch, x, y], the ordering of x, y here needs to be
|
||||
// swapped when scattering.
|
||||
scatter_dim_numbers.add_scatter_dims_to_operand_dims(2);
|
||||
scatter_dim_numbers.add_scatter_dims_to_operand_dims(1);
|
||||
|
||||
return xla::Scatter(grad_data, indices, updates,
|
||||
xla::CreateScalarAddComputation(xla_type, ctx->builder()),
|
||||
scatter_dim_numbers);
|
||||
}
|
||||
|
||||
// Build computation for the backprop into input 'data'.
|
||||
// Where input:
|
||||
// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
|
||||
// ratio is of dimension [batch, dim_0, ...dim_n, 2]
|
||||
// gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
|
||||
//
|
||||
// Output:
|
||||
// scatter-add to each 2x2 grad_data neighbor:
|
||||
// grad_data[fx, fy, chan] += output_grad * dx * dy
|
||||
// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
|
||||
// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
|
||||
// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
|
||||
// where (dx, dy) is (1 - ratio).
|
||||
XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
|
||||
XlaOp gather_indices, xla::PrimitiveType warp_type,
|
||||
TensorShape warp_shape, int64 data_channels,
|
||||
xla::Shape data_shape) {
|
||||
// Weights tensor has dimension [batch, dim_0, ... dim_n, 4].
|
||||
auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
|
||||
|
||||
auto warp_dims = warp_shape.dim_sizes();
|
||||
std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
|
||||
warp_dims.end() - 1);
|
||||
|
||||
std::vector<int64> reshaped_weights_dims = warp_dims_without_last_dims;
|
||||
// Reshape the last dimension of size 4 to two dimensions [2, 2].
|
||||
reshaped_weights_dims.push_back(2);
|
||||
reshaped_weights_dims.push_back(2);
|
||||
std::vector<int64> reshape_dims(warp_shape.dims());
|
||||
std::iota(reshape_dims.begin(), reshape_dims.end(), 0);
|
||||
// The dimension is [batch, dim_0,..., dim_n, 2, 2].
|
||||
auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims,
|
||||
/*new_sizes=*/reshaped_weights_dims);
|
||||
|
||||
std::vector<int64> weights_with_channels_dims = reshaped_weights_dims;
|
||||
weights_with_channels_dims.push_back(data_channels);
|
||||
auto weights_with_channels_shape =
|
||||
xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims);
|
||||
std::vector<int64> reshaped_weights_indices(reshaped_weights_dims.size());
|
||||
std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(),
|
||||
0);
|
||||
|
||||
// The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel].
|
||||
auto broadcast_reshaped_weights = xla::BroadcastInDim(
|
||||
reshaped_weights, weights_with_channels_shape, reshaped_weights_indices);
|
||||
|
||||
std::vector<int64> grad_output_indices(warp_dims_without_last_dims.size());
|
||||
std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0);
|
||||
grad_output_indices.push_back(weights_with_channels_dims.size() - 1);
|
||||
XlaOp broadcast_grad_output = xla::BroadcastInDim(
|
||||
grad_output, weights_with_channels_shape, grad_output_indices);
|
||||
|
||||
auto grad_output_multiply_weights =
|
||||
broadcast_grad_output * broadcast_reshaped_weights;
|
||||
|
||||
auto grad_data = xla::ConstantLiteral(
|
||||
ctx->builder(), xla::Literal::CreateFromShape(data_shape));
|
||||
|
||||
return ScatterToGradData(ctx, grad_data, gather_indices,
|
||||
grad_output_multiply_weights, warp_shape.dims(),
|
||||
warp_type);
|
||||
}
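As a sanity check on the formula documented above, a small standalone C++ sketch of the per-output, per-channel contributions to the four 2x2 neighbors, where (dx, dy) = (1 - ratio). The struct and function names are made up for illustration; the real computation is the broadcasted weights * grad_output product scattered by ScatterToGradData.

// Contributions of one output gradient value to its four 2x2 neighbors.
struct DataGrad2x2 {
  float fxfy, cxfy, fxcy, cxcy;
};

DataGrad2x2 BilinearDataGrad(float output_grad, float ratio_x, float ratio_y) {
  const float dx = 1.0f - ratio_x;
  const float dy = 1.0f - ratio_y;
  return {output_grad * dx * dy,                       // grad_data[fx, fy]
          output_grad * (1.0f - dx) * dy,              // grad_data[cx, fy]
          output_grad * dx * (1.0f - dy),              // grad_data[fx, cy]
          output_grad * (1.0f - dx) * (1.0f - dy)};    // grad_data[cx, cy]
}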
|
||||
|
||||
// Build computation for the backprop into input 'warp'.
// Where input:
//   warp is of dimension [batch, dim_0, ...dim_n, 2]
//   grad_output is of dimension [batch, dim_0, ...dim_n, channel]
//   ratio is of dimension [batch, dim_0, ...dim_n, 2]
//   gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
//   data is of dimension [batch, x, y, channel]
//
// Output (simplified by ignoring the batch dimensions):
//   Since the forward path has:
//     output = dot(weights, neighbors)
//   The backprop into warp will therefore be:
//     grad_warp = output_grad * d_output / d_warp
//               = output_grad * (d_weights / d_warp * neighbors + d_neighbors /
//                 d_warp * weight)
//   Where:
//     d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py]
//     d_weights / d_warp_y = [-(1 - px), -px, (1-px), px]
//   and
//     d_neighbors / d_warp_x = 0
//
//   Therefore:
//     grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy)
//     grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy)
//
// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the
// bottom right corner in a 2x2 neighborhood.
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
                        XlaOp gather_indices, XlaOp data,
                        TensorShape warp_shape, int64 data_channels,
                        xla::PrimitiveType data_type) {
auto warp_dims = warp_shape.dim_sizes();
|
||||
std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
|
||||
warp_dims.end() - 1);
|
||||
|
||||
std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims;
|
||||
neighbor_broadcast_dims.push_back(4);
|
||||
|
||||
// With dimension [batch, dim_0, ...dim_n, 4]
|
||||
auto neighbor_broadcast_shape =
|
||||
xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
|
||||
|
||||
// The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
|
||||
auto neighbors_data = Gather2by2Neighbors(
|
||||
ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
|
||||
|
||||
const int64 last_warp_dim = warp_shape.dims() - 1;
|
||||
|
||||
// Since we will be creating the dot product of:
|
||||
// lhs: [batch, dim_0, ...dim_n, 4]
|
||||
// and
|
||||
// rhs: [batch, dim_0, ...dim_n, 4, data_channels]
|
||||
// we choose the last dimension of lhs and the second last dimension of rhs,
|
||||
// with size 4, as the contracting dimension.
|
||||
xla::DotDimensionNumbers dot_dims;
|
||||
for (int i = 0; i < warp_shape.dims() - 1; ++i) {
|
||||
dot_dims.add_lhs_batch_dimensions(i);
|
||||
dot_dims.add_rhs_batch_dimensions(i);
|
||||
}
|
||||
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
|
||||
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
|
||||
|
||||
// img_cxcy - img_fxcy
|
||||
auto bottom_right_minus_bottom_left = xla::DotGeneral(
|
||||
xla::BroadcastInDim(
|
||||
xla::ConvertElementType(
|
||||
xla::ConstantR1<float>(ctx->builder(), {0, 0, -1, 1}), data_type),
|
||||
neighbor_broadcast_shape, {last_warp_dim}),
|
||||
neighbors_data, dot_dims, /*precision_config=*/nullptr);
|
||||
|
||||
// img_cxfy - img_fxfy
|
||||
auto top_right_minus_top_left = xla::DotGeneral(
|
||||
xla::BroadcastInDim(
|
||||
xla::ConvertElementType(
|
||||
xla::ConstantR1<float>(ctx->builder(), {-1, 1, 0, 0}), data_type),
|
||||
neighbor_broadcast_shape, {last_warp_dim}),
|
||||
neighbors_data, dot_dims, /*precision_config=*/nullptr);
|
||||
|
||||
// img_cxcy - img_cxfy
|
||||
auto bottom_right_minus_top_right = xla::DotGeneral(
|
||||
xla::BroadcastInDim(
|
||||
xla::ConvertElementType(
|
||||
xla::ConstantR1<float>(ctx->builder(), {0, -1, 0, 1}), data_type),
|
||||
neighbor_broadcast_shape, {last_warp_dim}),
|
||||
neighbors_data, dot_dims, /*precision_config=*/nullptr);
|
||||
|
||||
// img_fxcy - img_fxfy
|
||||
auto bottom_left_minus_top_left = xla::DotGeneral(
|
||||
xla::BroadcastInDim(
|
||||
xla::ConvertElementType(
|
||||
xla::ConstantR1<float>(ctx->builder(), {-1, 0, 1, 0}), data_type),
|
||||
neighbor_broadcast_shape, {last_warp_dim}),
|
||||
neighbors_data, dot_dims, /*precision_config=*/nullptr);
|
||||
|
||||
// Slice out x and y.
|
||||
auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1,
|
||||
/*stride=*/1, /*dimno=*/last_warp_dim);
|
||||
auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2,
|
||||
/*stride=*/1, /*dimno=*/last_warp_dim);
|
||||
|
||||
// Build 1 - y and 1 - x.
|
||||
auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y;
|
||||
auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x;
|
||||
|
||||
auto x_before_reduce =
|
||||
grad_output * weight_y * bottom_right_minus_bottom_left +
|
||||
one_minus_y * top_right_minus_top_left;
|
||||
|
||||
std::vector<int64> reshaped_sizes = warp_dims_without_last_dims;
|
||||
reshaped_sizes.push_back(1);
|
||||
|
||||
std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size());
|
||||
std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0);
|
||||
|
||||
// Reduce-add along the channel dimension.
|
||||
auto x_result =
|
||||
xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type),
|
||||
xla::CreateScalarAddComputation(data_type, ctx->builder()),
|
||||
{last_warp_dim});
|
||||
// Reshape before concatenating with y values.
|
||||
XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes);
|
||||
|
||||
auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right +
|
||||
one_minus_x * bottom_left_minus_top_left;
|
||||
// Reduce-add along the channel dimension.
|
||||
auto y_result =
|
||||
xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type),
|
||||
|
||||
xla::CreateScalarAddComputation(data_type, ctx->builder()),
|
||||
{last_warp_dim});
|
||||
XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes);
|
||||
|
||||
return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y},
|
||||
last_warp_dim);
|
||||
}
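A scalar reference of the documented grad_warp formulas, for one channel of one output location: px/py is the fractional warp position inside the 2x2 cell and img_* are the four neighbor values. The XLA computation above additionally multiplies by grad_output and reduce-adds over channels; the names here are illustrative only.

#include <utility>

// Returns {grad_warp_x, grad_warp_y} for a single channel.
std::pair<float, float> BilinearWarpGrad(float px, float py, float img_fxfy,
                                         float img_cxfy, float img_fxcy,
                                         float img_cxcy) {
  float grad_warp_x =
      py * (img_cxcy - img_fxcy) + (1.0f - py) * (img_cxfy - img_fxfy);
  float grad_warp_y =
      px * (img_cxcy - img_cxfy) + (1.0f - px) * (img_fxcy - img_fxfy);
  return {grad_warp_x, grad_warp_y};
}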
|
||||
|
||||
class ResamplerOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape data_shape = ctx->InputShape("data");
|
||||
OP_REQUIRES(ctx, data_shape.dims() == 4,
|
||||
errors::InvalidArgument("data must be 4-dimensional",
|
||||
data_shape.DebugString()));
|
||||
const int64 data_channels = data_shape.dim_size(3);
|
||||
xla::PrimitiveType data_type = ctx->input_xla_type(0);
|
||||
|
||||
TensorShape warp_shape = ctx->InputShape("warp");
|
||||
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
|
||||
errors::InvalidArgument("warp must be at least 2-dimensional",
|
||||
warp_shape.DebugString()));
|
||||
for (int size : warp_shape.dim_sizes()) {
|
||||
OP_REQUIRES(ctx, size > 0,
|
||||
errors::InvalidArgument("warp sizes must be positive, got [",
|
||||
size, "]"));
|
||||
}
|
||||
const int64 last_warp_dim = warp_shape.dims() - 1;
|
||||
// Last dimension of warp shape must be of size 2.
|
||||
OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
|
||||
errors::InvalidArgument(
|
||||
"the last dimension of warp must be exactly size 2."));
|
||||
|
||||
XlaOp data = ctx->Input("data");
|
||||
XlaOp warp = ctx->Input("warp");
|
||||
|
||||
// Find the coordinates of the top left corner for the 2x2 region to be
|
||||
// sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the
|
||||
// last dimension of size 2 in turn is [x, y].
|
||||
XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
|
||||
|
||||
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
|
||||
|
||||
// The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
|
||||
auto neighbors_data = Gather2by2Neighbors(
|
||||
ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
|
||||
|
||||
// Dimensions are [batch, dim_0, ... dim_n, 2].
|
||||
XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type);
|
||||
|
||||
// Obtain the bilinear blending weights, the dimension is [batch, dim_0,
|
||||
// ...dim_n, 4].
|
||||
auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type);
|
||||
|
||||
// Since we will be creating the dot product of:
|
||||
// lhs: [batch, dim_0, ...dim_n, 4]
|
||||
// and
|
||||
// rhs: [batch, dim_0, ...dim_n, 4, data_channels]
|
||||
// we choose the last dimension of lhs and the second last dimension of rhs,
|
||||
// with size 4, as the contracting dimension.
|
||||
xla::DotDimensionNumbers dot_dims;
|
||||
for (int i = 0; i < warp_shape.dims() - 1; ++i) {
|
||||
dot_dims.add_lhs_batch_dimensions(i);
|
||||
dot_dims.add_rhs_batch_dimensions(i);
|
||||
}
|
||||
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
|
||||
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
|
||||
|
||||
    auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
                                          /*precision_config=*/nullptr);

    ctx->SetOutput(0, blended_pixels);
  }
};

REGISTER_XLA_OP(Name("Resampler"), ResamplerOp);
|
||||
|
||||
class ResamplerGradOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
DataType output_dtype;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape data_shape_tf = ctx->InputShape("data");
|
||||
OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
|
||||
errors::InvalidArgument("data must be 4-dimensional",
|
||||
data_shape_tf.DebugString()));
|
||||
const int64 data_channels = data_shape_tf.dim_size(3);
|
||||
xla::PrimitiveType data_type = ctx->input_xla_type(0);
|
||||
|
||||
TensorShape warp_shape = ctx->InputShape("warp");
|
||||
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
|
||||
errors::InvalidArgument("warp must be at least 2-dimensional",
|
||||
warp_shape.DebugString()));
|
||||
for (int size : warp_shape.dim_sizes()) {
|
||||
OP_REQUIRES(ctx, size > 0,
|
||||
errors::InvalidArgument("warp sizes must be positive, got [",
|
||||
size, "]"));
|
||||
}
|
||||
// Last dimension of warp shape must be of size 2.
|
||||
OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2,
|
||||
errors::InvalidArgument(
|
||||
"the last dimension of warp must be exactly size 2."));
|
||||
xla::PrimitiveType warp_type = ctx->input_xla_type(1);
|
||||
|
||||
TensorShape output_grad_shape = ctx->InputShape("grad_output");
|
||||
OP_REQUIRES(
|
||||
ctx, output_grad_shape.dims() >= 2,
|
||||
errors::InvalidArgument("output_grad must be at least 2-dimensional",
|
||||
output_grad_shape.DebugString()));
|
||||
|
||||
// Dimensions are [batch, x, y, channel].
|
||||
XlaOp data = ctx->Input("data");
|
||||
xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf);
|
||||
|
||||
// Dimensions are [batch, dim_0, ...dim_n, 2].
|
||||
XlaOp warp = ctx->Input("warp");
|
||||
// Dimensions are [batch, dim_0, ...dim_n, channel].
|
||||
XlaOp grad_output = ctx->Input("grad_output");
|
||||
|
||||
// Find the top left corner coordinate for the region to be sampled from.
|
||||
// The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
|
||||
// of size 2 in turn is [x, y].
|
||||
XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
|
||||
|
||||
// Dimensions are [batch, dim_0, ... dim_n, 2]
|
||||
XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
|
||||
|
||||
// Indices for gathering neighboring pixels.
|
||||
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
|
||||
|
||||
auto grad_data =
|
||||
CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type,
|
||||
warp_shape, data_channels, data_shape);
|
||||
|
||||
auto grad_warp =
|
||||
CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
|
||||
warp_shape, data_channels, data_type);
|
||||
|
||||
ctx->SetOutput(0, grad_data);
|
||||
ctx->SetOutput(1, grad_warp);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/numeric.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -61,113 +63,79 @@ class ReverseSequenceOp : public XlaOpKernel {
|
||||
const auto seq_lens = context->Input(1);
|
||||
|
||||
const int64 batch_size = input_shape.dim_size(batch_dim_);
|
||||
if (batch_size == 0) {
|
||||
context->SetOutput(0, input);
|
||||
return;
|
||||
}
|
||||
|
||||
const DataType input_type = context->input_type(0);
|
||||
const DataType seq_lens_type = context->input_type(1);
|
||||
    // Given the input
    //
    //   012345
    //   6789AB
    //
    // and sequence lens {2, 3} we:
    //
    // 1. Reverse and pad each row to get
    //
    //   543210XXXXXX
    //   BA9876XXXXXX
    //
    // 2. Gather out the suffix from each row to get
    //
    //   10XXXX
    //   876XXX
    //
    // 3. Select from the input and the array created by (2) to get the result.
    //
    //   102345
    //   8769AB
const xla::PrimitiveType input_type = context->input_xla_type(0);
|
||||
const xla::PrimitiveType seq_lens_type = context->input_xla_type(1);
|
||||
const int64 max_seq_len = input_shape.dim_size(seq_dim_);
|
||||
|
||||
xla::Shape input_xla_shape;
|
||||
OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
|
||||
&input_xla_shape));
|
||||
xla::Shape seq_lens_xla_shape;
|
||||
OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
|
||||
&seq_lens_xla_shape));
|
||||
xla::XlaOp rev = xla::Rev(input, {seq_dim_});
|
||||
|
||||
const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
|
||||
xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
|
||||
seq_lens_xla_shape,
|
||||
input_xla_shape,
|
||||
});
|
||||
auto padding_config = xla::MakeNoPaddingConfig(input_shape.dims());
|
||||
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
|
||||
max_seq_len);
|
||||
xla::XlaOp padded =
|
||||
xla::Pad(rev, xla::Zero(builder, input_type), padding_config);
|
||||
|
||||
// For each entry in the batch, reverse the sequence.
|
||||
// TODO(b/65689298): generalize the Map() operator to non-scalar cases and
|
||||
// use it here, instead of a While loop.
|
||||
// Form a start indices tensor with shape [2, batch_size]. For each batch
|
||||
// entry we have a (batch offset, seq offset) pair.
|
||||
xla::XlaOp start_indices = xla::ConcatInDim(
|
||||
builder,
|
||||
{
|
||||
xla::Iota(builder,
|
||||
xla::ShapeUtil::MakeShape(seq_lens_type, {1, batch_size}),
|
||||
/*iota_dimension=*/1),
|
||||
xla::Reshape(xla::ScalarLike(seq_lens, max_seq_len) - seq_lens,
|
||||
{1, batch_size}),
|
||||
},
|
||||
/*dimension=*/0);
|
||||
|
||||
// Condition: lambda (i, _, _): i < batch_size
|
||||
auto condition_builder =
|
||||
builder->CreateSubBuilder("reverse_sequence_condition");
|
||||
{
|
||||
auto param =
|
||||
xla::Parameter(condition_builder.get(), 0, tuple_shape, "param");
|
||||
auto i = xla::GetTupleElement(param, 0);
|
||||
xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(),
|
||||
seq_lens_type, batch_size));
|
||||
xla::GatherDimensionNumbers dnums;
|
||||
// The first dimension of start_indices contains the batch/seq dim choice.
|
||||
dnums.set_index_vector_dim(0);
|
||||
dnums.add_start_index_map(batch_dim_);
|
||||
dnums.add_start_index_map(seq_dim_);
|
||||
|
||||
// All other dimensions other than the batch dim are offset dimensions.
|
||||
for (int i = 0; i < input_shape.dims(); ++i) {
|
||||
if (i != batch_dim_) {
|
||||
dnums.add_offset_dims(i);
|
||||
}
|
||||
}
|
||||
auto condition = condition_builder->Build();
|
||||
OP_REQUIRES_OK(context, condition.status());
|
||||
dnums.add_collapsed_slice_dims(batch_dim_);
|
||||
|
||||
auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
|
||||
{
|
||||
auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param");
|
||||
auto i = xla::GetTupleElement(param, 0);
|
||||
auto seq_lens = xla::GetTupleElement(param, 1);
|
||||
auto output = xla::GetTupleElement(param, 2);
|
||||
auto slice_sizes = input_shape.dim_sizes();
|
||||
slice_sizes[batch_dim_] = 1;
|
||||
|
||||
// seq_len is the sequence length of the current batch element (rank 1)
|
||||
auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1});
|
||||
xla::XlaOp output = xla::Gather(padded, start_indices, dnums, slice_sizes);
|
||||
|
||||
// Indices is the offset of the batch element in the input.
|
||||
auto batch_element_indices =
|
||||
xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
|
||||
{input_shape.dims()});
|
||||
batch_element_indices = xla::DynamicUpdateSlice(
|
||||
batch_element_indices, xla::Reshape(i, {1}),
|
||||
xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
|
||||
seq_lens_type, batch_dim_),
|
||||
{1}));
|
||||
|
||||
// Slice out the current batch element and pad it out in the sequence
|
||||
// dimension.
|
||||
TensorShape slice_shape = input_shape;
|
||||
slice_shape.set_dim(batch_dim_, 1);
|
||||
slice_shape.set_dim(seq_dim_, max_seq_len);
|
||||
auto slice = xla::DynamicSlice(output, batch_element_indices,
|
||||
slice_shape.dim_sizes());
|
||||
auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
|
||||
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
|
||||
slice_shape.dim_size(seq_dim_));
|
||||
slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type),
|
||||
padding_config);
|
||||
|
||||
// Now slice out the reversed sequence from its actual start.
|
||||
// sequence_start_indices is the offset of the start of the reversed
|
||||
// sequence in the input. The slice will go into the padding, however, we
|
||||
// will mask off these elements and replace them with elements from the
|
||||
// original input so their values do not matter.
|
||||
auto sequence_start_indices =
|
||||
xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
|
||||
{slice_shape.dims()});
|
||||
sequence_start_indices = xla::DynamicUpdateSlice(
|
||||
sequence_start_indices,
|
||||
xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
|
||||
max_seq_len),
|
||||
seq_len),
|
||||
xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
|
||||
seq_lens_type, seq_dim_),
|
||||
{1}));
|
||||
slice = xla::DynamicSlice(slice, sequence_start_indices,
|
||||
slice_shape.dim_sizes());
|
||||
|
||||
// Shift the reversed sequence to the left.
|
||||
output = xla::DynamicUpdateSlice(output, slice, batch_element_indices);
|
||||
|
||||
xla::Tuple(
|
||||
body_builder.get(),
|
||||
{xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
|
||||
seq_lens, output});
|
||||
}
|
||||
auto body = body_builder->Build();
|
||||
OP_REQUIRES_OK(context, body.status());
|
||||
|
||||
auto loop_output = xla::While(
|
||||
condition.ValueOrDie(), body.ValueOrDie(),
|
||||
xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
|
||||
xla::Rev(input, {seq_dim_})}));
|
||||
auto output = xla::GetTupleElement(loop_output, 2);
|
||||
|
||||
    // Mask out elements after the sequence length.
    xla::XlaOp iota =
        xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len);
    // Mask out elements after the sequence length, and copy the corresponding
    // elements from the input.
    xla::XlaOp iota = xla::Iota(builder, seq_lens_type, max_seq_len);
    std::vector<int64> dims(input_shape.dims(), 1);
    dims[batch_dim_] = batch_size;
auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});
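A compact CPU sketch of the end-to-end behavior being compiled here, matching the worked example in the comment above (rows "012345"/"6789AB" with sequence lengths {2, 3} become "102345"/"8769AB"): for each batch row, reverse the first seq_len elements and leave the rest untouched. The helper name is hypothetical and strings stand in for the sequence dimension.

#include <algorithm>
#include <string>
#include <vector>

std::vector<std::string> ReverseSequenceRows(std::vector<std::string> rows,
                                             const std::vector<int>& seq_lens) {
  for (size_t i = 0; i < rows.size(); ++i) {
    // Reverse only the prefix of length seq_lens[i]; the suffix is kept as-is.
    std::reverse(rows[i].begin(), rows[i].begin() + seq_lens[i]);
  }
  return rows;
}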
|
||||
|
@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp);
|
||||
REGISTER_XLA_OP(
|
||||
Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
|
||||
StackOp);
|
||||
|
||||
class StackPushOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
|
||||
REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);
|
||||
|
||||
class StackPopOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
|
||||
REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);
|
||||
|
||||
class StackCloseOp : public XlaOpKernel {
|
||||
public:
|
||||
@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
|
||||
REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
@ -75,6 +76,222 @@ Status CheckFeedFetchNameConflicts(const string& kind,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
|
||||
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
|
||||
Status CopyAssociatedFunctions(Graph* g,
|
||||
const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
for (Node* n : g->op_nodes()) {
|
||||
for (const auto& associated_function :
|
||||
GetAssociatedFunctions(*n, lookup_fld)) {
|
||||
switch (associated_function.type()) {
|
||||
case AssociatedFunctionInfo::kFunctionCallNode: {
|
||||
const FunctionDef* fdef =
|
||||
lookup_fld->Find(associated_function.func_name());
|
||||
if (!fdef) {
|
||||
return errors::Internal(
|
||||
"Cannot find function ", associated_function.func_name(),
|
||||
" for function call node ", n->DebugString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
|
||||
break;
|
||||
}
|
||||
case AssociatedFunctionInfo::kSymbolicGradient:
|
||||
case AssociatedFunctionInfo::kFunctionAttr:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For graph `g`, replaces _Arg nodes whose "index" attribute is in
|
||||
// `const_input_index_to_node` with Const nodes.
|
||||
Status ReplaceArgUsageWithConstNode(
|
||||
Graph* g,
|
||||
const std::unordered_map<int, const Node*>& const_input_index_to_node) {
|
||||
// Collect all _Arg nodes.
|
||||
std::unordered_map<int, Node*> arg_nodes;
|
||||
for (Node* n : g->op_nodes()) {
|
||||
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
arg_nodes[index] = n;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& iter : const_input_index_to_node) {
|
||||
int arg_index = iter.first;
|
||||
Node* const_node = g->CopyNode(iter.second);
|
||||
Node* arg_node = arg_nodes[arg_index];
|
||||
|
||||
// Collect all usages of the _Arg node.
|
||||
struct OutEdgeInfo {
|
||||
int dst_node_id, dst_input;
|
||||
};
|
||||
std::vector<OutEdgeInfo> usages;
|
||||
for (const Edge* e : arg_node->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
usages.push_back({e->dst()->id(), e->dst_input()});
|
||||
}
|
||||
|
||||
for (int i = 0; i < usages.size(); i++) {
|
||||
// Make a copy of `usage_node`, and change its input to const node.
|
||||
Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
|
||||
NodeDef replace_def = usage_node->def();
|
||||
*replace_def.mutable_input(usages[i].dst_input) = const_node->name();
|
||||
TF_ASSIGN_OR_RETURN(Node * replace_node,
|
||||
ReplaceNode(g, usage_node, replace_def));
|
||||
const Edge* usage_edge;
|
||||
TF_RETURN_IF_ERROR(
|
||||
replace_node->input_edge(usages[i].dst_input, &usage_edge));
|
||||
g->RemoveEdge(usage_edge);
|
||||
g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
|
||||
|
||||
// Later entries in `usages` might have `usage_node` as dst node, but
|
||||
// `usage_node` is removed. Replace such entries with `replace_node`.
|
||||
for (int j = i + 1; j < usages.size(); j++) {
|
||||
if (usages[j].dst_node_id == usages[i].dst_node_id) {
|
||||
usages[j].dst_node_id = replace_node->id();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
|
||||
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
|
||||
// inputs.
|
||||
Status PropagateConstIntoFuncAttr(
|
||||
Node* n, const string& attr_name,
|
||||
const std::unordered_map<int, const Node*>& const_input_index_to_node,
|
||||
const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
// Instantiate the function.
|
||||
NameAttrList func_attr;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
|
||||
const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
|
||||
if (!fdef) {
|
||||
return errors::Internal("Cannot find function ", func_attr.name(),
|
||||
" for node ", n->name());
|
||||
}
|
||||
FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*fdef, AttrSlice(&func_attr.attr()), lookup_fld,
|
||||
[lookup_fld](const string& op, const OpDef** sig) {
|
||||
return lookup_fld->LookUpOpDef(op, sig);
|
||||
},
|
||||
&fbody));
|
||||
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
|
||||
|
||||
// Rewrite _Arg usages with Const node.
|
||||
Graph* func_graph = fbody->graph;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
|
||||
|
||||
// Save rewritten function.
|
||||
FunctionDef replace_fdef;
|
||||
string new_func_name =
|
||||
fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
|
||||
|
||||
// Change the node to use rewritten function.
|
||||
func_attr.set_name(new_func_name);
|
||||
n->ClearAttr(attr_name);
|
||||
n->AddAttr(attr_name, func_attr);
|
||||
|
||||
// Copy associated functions.
|
||||
TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
|
||||
// then/else branch function to replace _Arg nodes with those Const inputs.
|
||||
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
|
||||
const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
// Notice that first input for If node is predicate; other inputs are function
|
||||
// inputs.
|
||||
std::unordered_map<int, const Node*> const_input_index_to_node;
|
||||
for (int i = 1; i < if_node->num_inputs(); i++) {
|
||||
const Node* input_node;
|
||||
TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
|
||||
if (input_node->type_string() == "Const") {
|
||||
const_input_index_to_node[i - 1] = input_node;
|
||||
}
|
||||
}
|
||||
if (const_input_index_to_node.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrite "then_branch" and "else_branch" function, replace usage of those
|
||||
// _Arg nodes with corresponding const node.
|
||||
for (const auto& attr_name :
|
||||
std::vector<string>{"then_branch", "else_branch"}) {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
|
||||
if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
|
||||
// cond/body function to replace _Arg nodes with those Const inputs.
|
||||
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
|
||||
const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
// For "While" node, we should only replace _Arg nodes which are loop
|
||||
// invariants. For such _Arg nodes, the return value's input will come
|
||||
// directly from the corresponding arg.
|
||||
std::unordered_map<int, const Node*> const_input_index_to_node;
|
||||
NameAttrList body_attr;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
|
||||
const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
|
||||
if (!body_func) {
|
||||
return errors::Internal("Cannot find body function ", body_attr.name(),
|
||||
" for While node ", while_node->name());
|
||||
}
|
||||
for (int i = 0; i < while_node->num_inputs(); i++) {
|
||||
const Node* input_node;
|
||||
TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
|
||||
if (input_node->type_string() != "Const") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if i-th retval's input comes from i-th arg directly.
|
||||
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
|
||||
auto output_arg_input = body_func->ret().find(output_arg.name());
|
||||
if (output_arg_input == body_func->ret().end()) {
|
||||
return errors::Internal("Cannot find input for output arg ",
|
||||
output_arg.name(), " in function ",
|
||||
body_attr.name());
|
||||
}
|
||||
const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
|
||||
if (output_arg_input->second != input_arg.name()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const_input_index_to_node[i] = input_node;
|
||||
}
|
||||
if (const_input_index_to_node.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
|
||||
// corresponding const node.
|
||||
for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
|
||||
while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
|
||||
@ -520,4 +737,17 @@ xla::StatusOr<Node*> BuildIdentityNode(
|
||||
return id_node;
|
||||
}
|
||||
|
||||
Status PropagateConstIntoFunctionalNodes(
|
||||
Graph* g, const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
for (Node* n : g->op_nodes()) {
|
||||
if (n->type_string() == "If") {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
|
||||
} else if (n->type_string() == "While") {
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -183,6 +183,20 @@ xla::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
|
||||
DataType dtype, const Node* input,
|
||||
absl::optional<string> requested_device);
|
||||
|
||||
// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
|
||||
// body functions to use the Const nodes instead of original _Arg nodes.
|
||||
//
|
||||
// For example, say we have the following computation:
|
||||
// shape = constant_op.constant([1])
|
||||
// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
|
||||
// If we do not rewrite then/else function, they will use _Arg node as shape
|
||||
// input for tf.ones/tf.zeros. But XLA requires that shape input to be compile
|
||||
// time constant, so XLA compilation will fail. This rewriting process will
|
||||
// change the shape input to Const node.
|
||||
Status PropagateConstIntoFunctionalNodes(
|
||||
Graph* g, const FunctionLibraryDefinition* lookup_fld,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
@ -756,6 +756,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
CompilationResult* result) {
|
||||
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
|
||||
|
||||
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
|
||||
graph.get(), options_.flib_def, local_flib_def_.get()));
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "XlaCompiler::CompileGraph: "
|
||||
<< dump_graph::DumpGraphToFile(
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
// Lazily register the CPU and GPU JIT devices the first time
|
||||
// GetCompilationDevice is called.
|
||||
static void* registration_init = [®istry]() {
|
||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
||||
legacy_flags::GetMarkForCompilationPassFlags();
|
||||
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
|
||||
|
||||
mutex_lock lock(registry.mutex_);
|
||||
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
|
||||
DeviceRegistration& registration =
|
||||
registry.compilation_devices_[DEVICE_CPU];
|
||||
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
|
||||
registration.requires_compilation = false;
|
||||
registration.enable_jit_by_default = false;
|
||||
registration.autoclustering_policy =
|
||||
cpu_global_jit
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
|
||||
registration.compile_resource_ops = false;
|
||||
}
|
||||
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
|
||||
DeviceRegistration& registration =
|
||||
registry.compilation_devices_[DEVICE_GPU];
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.requires_compilation = false;
|
||||
registration.enable_jit_by_default = true;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
||||
registration.compile_resource_ops = false;
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -66,19 +66,26 @@ class XlaOpRegistry {
|
||||
public:
|
||||
typedef OpKernel* (*Factory)(OpKernelConstruction*);
|
||||
|
||||
enum class AutoclusteringPolicy {
|
||||
// Enable autoclustering if the user requests it, e.g., via
|
||||
// experimental_jit_scope. Does not autocluster if the JIT is enabled
|
||||
// globally (e.g., via the OptimizerOptions in the TF session
|
||||
// configuration.)
|
||||
kIfExplicitlyRequested,
|
||||
// Enable autoclustering if explicitly requested, or if the JIT is enabled
|
||||
// globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
|
||||
kIfEnabledGlobally,
|
||||
// Always try to autocluster ops placed on this device.
|
||||
kAlways,
|
||||
};
|
||||
|
||||
// Describes how to compile operators assigned to a device.
|
||||
struct DeviceRegistration {
|
||||
// The name of the XLA compilation device to use to compile code.
|
||||
string compilation_device_name;
|
||||
|
||||
// Do operators assigned to this device require compilation?
|
||||
bool requires_compilation;
|
||||
|
||||
// If !requires_compilation, should we try to JIT operators on this device
|
||||
// when XLA JIT compilation is enabled globally via the SessionOptions?
|
||||
// (It is still possible to explicitly mark operators to JIT compile, even
|
||||
// if enable_jit_by_default is false.)
|
||||
bool enable_jit_by_default;
|
||||
// When should we autocluster operators assigned to this device?
|
||||
AutoclusteringPolicy autoclustering_policy;
|
||||
|
||||
// Enable compilation of operators that use DT_RESOURCE types?
|
||||
bool compile_resource_ops = false;
|
||||
|
@ -7,6 +7,7 @@ package_group(
|
||||
packages = [
|
||||
"//tensorflow/compiler/...",
|
||||
"//tensorflow/contrib/tpu/...",
|
||||
"//third_party/py/jax/...",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
<p align="center">
|
||||
<img width="200" src="xlalogo.png"/>
|
||||
<img width="200" src="./g3doc/images/xlalogo.png"/>
|
||||
</p>
|
||||
|
||||
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
|
||||
algebra that optimizes TensorFlow computations. See the
|
||||
[documentation](https://www.tensorflow.org/performance/xla/) for more details.
|
||||
algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/overview.md).
|
||||
|
@ -1,3 +1,3 @@
|
||||
# XLA: Accelerated Linear Algebra
|
||||
|
||||
These are the docs for: https://www.tensorflow.org/extend/xla
|
||||
These are the docs for: https://www.tensorflow.org/xla
|
||||
|
29
tensorflow/compiler/xla/g3doc/_book.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
upper_tabs:
|
||||
# Tabs left of dropdown menu
|
||||
- include: /_upper_tabs_left.yaml
|
||||
- include: /api_docs/_upper_tabs_api.yaml
|
||||
# Dropdown menu
|
||||
- name: Ecosystem
|
||||
path: /ecosystem
|
||||
is_default: true
|
||||
menu:
|
||||
- include: /ecosystem/_menu_toc.yaml
|
||||
lower_tabs:
|
||||
# Subsite tabs
|
||||
other:
|
||||
- name: Guide
|
||||
contents:
|
||||
- title: XLA overview
|
||||
path: /xla/overview
|
||||
- title: Broadcasting semantics
|
||||
path: /xla/broadcasting
|
||||
- title: Developing a new backend for XLA
|
||||
path: /xla/developing_new_backend
|
||||
- title: Using JIT compilation
|
||||
path: /xla/jit
|
||||
- title: Operation semantics
|
||||
path: /xla/operation_semantics
|
||||
- title: Shapes and layout
|
||||
path: /xla/shapes
|
||||
- title: Using AOT compilation
|
||||
path: /xla/tfcompile
|
35
tensorflow/compiler/xla/g3doc/_index.yaml
Normal file
@ -0,0 +1,35 @@
|
||||
book_path: /xla/_book.yaml
|
||||
project_path: /xla/_project.yaml
|
||||
description: <!--no description-->
|
||||
landing_page:
|
||||
custom_css_path: /site-assets/css/style.css
|
||||
rows:
|
||||
- heading: XLA is a compiler that optimizes TensorFlow computations.
|
||||
items:
|
||||
- classname: devsite-landing-row-50
|
||||
description: >
|
||||
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
|
||||
algebra that optimizes TensorFlow computations. The results are
|
||||
improvements in speed, memory usage, and portability on server and mobile
|
||||
platforms. The XLA framework is experimental and in active development.
|
||||
For details, read the <a href="./overview">XLA guide</a>.
|
||||
|
||||
- classname: devsite-landing-row-cards
|
||||
items:
|
||||
- heading: XLA - TensorFlow, compiled
|
||||
image_path: /ecosystem/images/tf-logo-card-16x9.png
|
||||
path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html
|
||||
buttons:
|
||||
- label: Read on Google Developers blog
|
||||
path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html
|
||||
- heading: XLA at the Dev Summit
|
||||
youtube_id: kAOanJczHA0
|
||||
buttons:
|
||||
- label: Watch the video
|
||||
path: https://www.youtube.com/watch?v=kAOanJczHA0
|
||||
- heading: XLA on GitHub
|
||||
image_path: /ecosystem/images/github-card-16x9.png
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla
|
@ -1,6 +1,6 @@
|
||||
name: XLA
|
||||
breadcrumb_name: XLA
|
||||
home_url: /extend/xla
|
||||
home_url: /xla/
|
||||
parent_project_metadata_path: /_project.yaml
|
||||
description: >
|
||||
XLA is a compiler-based linear algebra execution engine.
|
||||
|
@ -1,16 +0,0 @@
|
||||
toc:
|
||||
- heading: XLA
|
||||
- title: XLA overview
|
||||
path: /extend/xla/
|
||||
- title: Broadcasting semantics
|
||||
path: /extend/xla/broadcasting
|
||||
- title: Developing a new backend for XLA
|
||||
path: /extend/xla/developing_new_backend
|
||||
- title: Using JIT compilation
|
||||
path: /extend/xla/jit
|
||||
- title: Operation semantics
|
||||
path: /extend/xla/operation_semantics
|
||||
- title: Shapes and layout
|
||||
path: /extend/xla/shapes
|
||||
- title: Using AOT compilation
|
||||
path: /extend/xla/tfcompile
|
@ -29,8 +29,6 @@ namespace xla {
|
||||
/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
|
||||
const Shape& shape, absl::Span<const int64> multi_index) {
|
||||
DCHECK_EQ(shape.dimensions_size(), multi_index.size());
|
||||
// Padding and nested layouts not supported yet.
|
||||
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
|
||||
|
||||
for (size_t i = 0; i < multi_index.size(); ++i) {
|
||||
DCHECK_GE(multi_index[i], 0);
|
||||
@ -94,8 +92,6 @@ namespace xla {
|
||||
|
||||
/* static */ std::vector<int64> IndexUtil::LinearIndexToMultidimensionalIndex(
|
||||
const Shape& shape, int64 linear_index) {
|
||||
// Padding and nested layouts not supported yet.
|
||||
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
|
||||
DCHECK_GE(linear_index, 0);
|
||||
DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape));
|
||||
|
||||
@ -133,18 +129,12 @@ namespace xla {
|
||||
|
||||
/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape,
                                                 int64 dimension) {
  int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size();
  int64 stride = 1;
  DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size());
  for (auto dim : LayoutUtil::MinorToMajor(shape)) {
    if (dim == dimension) {
      break;
    }
    if (pdim_size == 0) {
      stride *= shape.dimensions(dim);
    } else {
      stride *= LayoutUtil::PaddedDimension(shape, dim);
    }
    stride *= shape.dimensions()[dim];
  }
  return stride;
}
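The stride rule above can be stated without any XLA types; a small standalone sketch follows, using the header's own example (dims {5, 8, 10, 4} with minor-to-major layout {3, 2, 1, 0}: dimension 1 has stride 4 * 10 = 40). The function name and the plain std::vector signature are assumptions of this sketch.

#include <cstdint>
#include <vector>

// Multiplies the sizes of all dimensions that are more minor than `dimension`
// in the given minor-to-major order.
int64_t DimensionStride(const std::vector<int64_t>& dims,
                        const std::vector<int>& minor_to_major, int dimension) {
  int64_t stride = 1;
  for (int dim : minor_to_major) {
    if (dim == dimension) break;
    stride *= dims[dim];
  }
  return stride;
}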
|
||||
|
@ -61,8 +61,7 @@ class IndexUtil {
|
||||
static bool BumpIndices(const Shape& shape, absl::Span<int64> indices);
|
||||
|
||||
// Calculates the stride size (in number of elements, not byte size) of a
|
||||
// given logical shape dimension (from 0 to rank-1). If available, padded
|
||||
// dimensions are used.
|
||||
// given logical shape dimension (from 0 to rank-1).
|
||||
// Example:
|
||||
// GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) ==
|
||||
// sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10
|
||||
|
@ -201,8 +201,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
|
||||
if (!ShapeUtil::IsArray(shape)) {
|
||||
if (layout.minor_to_major_size() != 0 ||
|
||||
layout.padded_dimensions_size() != 0) {
|
||||
if (layout.minor_to_major_size() != 0) {
|
||||
return InvalidArgument(
|
||||
"shape of primitive type %s should not have a non-trivial layout",
|
||||
PrimitiveType_Name(shape.element_type()));
|
||||
@ -241,28 +240,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
dimensions_in_layout[dim] = true;
|
||||
}
|
||||
|
||||
if (layout.padded_dimensions_size() > 0) {
|
||||
if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) {
|
||||
return InvalidArgument(
|
||||
"layout has %d padded dimensions, but shape is rank %d",
|
||||
layout.padded_dimensions_size(), ShapeUtil::Rank(shape));
|
||||
}
|
||||
for (int i = 0; i < layout.padded_dimensions_size(); ++i) {
|
||||
if (layout.padded_dimensions(i) < shape.dimensions(i)) {
|
||||
return InvalidArgument(
|
||||
"for dimension %d, dimension padding (%d) is smaller than "
|
||||
"the dimension size (%d) of the shape",
|
||||
i, layout.padded_dimensions(i), shape.dimensions(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (layout.format() == SPARSE) {
|
||||
if (!layout.padded_dimensions().empty()) {
|
||||
return InvalidArgument("Sparse layout has padded dimensions");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -303,38 +280,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
layout.minor_to_major().end(), std::greater<int64>());
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
|
||||
if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
|
||||
shape.layout().padded_dimensions_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
CHECK(IsDenseArray(shape)) << shape.ShortDebugString();
|
||||
CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
|
||||
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
|
||||
if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ absl::Span<const int64> LayoutUtil::PaddedDimensions(
|
||||
const Shape& shape) {
|
||||
CHECK(IsDenseArray(shape));
|
||||
return AsInt64Slice(shape.layout().padded_dimensions());
|
||||
}
|
||||
|
||||
/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape,
|
||||
int64 index) {
|
||||
CHECK(IsDenseArray(shape));
|
||||
return shape.layout().padded_dimensions(index);
|
||||
}
|
||||
|
||||
/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) {
|
||||
CHECK(IsDenseArray(shape));
|
||||
return shape.layout().padding_value();
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
|
||||
return ShapeUtil::IsArray(shape) && shape.has_layout() &&
|
||||
IsSparse(shape.layout());
|
||||
@ -513,13 +458,6 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) {
|
||||
for (int64 minor_to_major : layout.minor_to_major()) {
|
||||
hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
|
||||
}
|
||||
|
||||
for (int64 padded_dim : layout.padded_dimensions()) {
|
||||
hash_value = Hash64Combine(hash_value, hash<int64>()(padded_dim));
|
||||
}
|
||||
|
||||
hash_value =
|
||||
Hash64Combine(hash_value, hash<PaddingValue>()(layout.padding_value()));
|
||||
hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
|
||||
|
||||
return hash_value;
|
||||
|
@ -104,23 +104,6 @@ class LayoutUtil {
|
||||
// more minor, and so on until dimension N-1 which is the minor.
|
||||
static bool IsMonotonicWithDim0Major(const Layout& layout);
|
||||
|
||||
// Returns whether the layout of the given shape has padding (a
|
||||
// padded_dimension value in Layout is greater than the corresponding
|
||||
// dimension size).
|
||||
static bool IsPadded(const Shape& shape);
|
||||
|
||||
// Returns the padded_dimensions array for the given Shape. Requires that the
|
||||
// shape is an array and has a dense layout.
|
||||
static absl::Span<const int64> PaddedDimensions(const Shape& shape);
|
||||
|
||||
// Returns the given index of the padded_dimensions array for the given Shape.
|
||||
// Requires that the shape is an array and has a dense layout.
|
||||
static int64 PaddedDimension(const Shape& shape, int64 index);
|
||||
|
||||
// Returns the padding_value for the given Shape. Requires that the shape is
|
||||
// an array and has a dense layout.
|
||||
static PaddingValue GetPaddingValue(const Shape& shape);
|
||||
|
||||
// Returns whether the given Shape is an array (i.e. not a tuple) and has a
|
||||
// sparse format layout.
|
||||
static bool IsSparseArray(const Shape& shape);
|
||||
|
@ -304,30 +304,6 @@ TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
|
||||
shape.tuple_shapes(1).layout()));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, IsPadded) {
|
||||
Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
|
||||
LayoutUtil::ClearLayout(&shape_without_layout);
|
||||
EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout));
|
||||
|
||||
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
|
||||
LayoutUtil::SetToDefaultLayout(&shape_with_layout);
|
||||
EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout));
|
||||
|
||||
// Add padding equal to the dimension sizes. In this case the padding is a
|
||||
// nop.
|
||||
Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
|
||||
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2);
|
||||
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3);
|
||||
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4);
|
||||
EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding));
|
||||
|
||||
Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
|
||||
shape_with_padding.mutable_layout()->add_padded_dimensions(2);
|
||||
shape_with_padding.mutable_layout()->add_padded_dimensions(14);
|
||||
shape_with_padding.mutable_layout()->add_padded_dimensions(42);
|
||||
EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||
LayoutUtil::GetDefaultLayoutForR2()));
|
||||
|
@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
|
||||
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
|
||||
PrimitiveType element_type = subshape.element_type();
|
||||
if (element_type == PRED) {
|
||||
// We display predicates in a densely packed form.
|
||||
return literal.Get<bool>(indices, shape_index) ? "1" : "0";
|
||||
}
|
||||
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
|
||||
literal.GetAsString(indices, shape_index);
|
||||
// We display predicates as 0s and 1s so that the string is more dense.
|
||||
string elem = element_type == PRED
|
||||
? literal.Get<bool>(indices, shape_index) ? "1" : "0"
|
||||
: literal.GetAsString(indices, shape_index);
|
||||
return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem;
|
||||
};
|
||||
|
||||
if (ShapeUtil::Rank(subshape) == 0) {
|
||||
|
@ -34,16 +34,22 @@ namespace xla {
|
||||
namespace literal_comparison {
|
||||
namespace {
|
||||
|
||||
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
  return val;
}
|
||||
uint16 GetRawValue(Eigen::half val) { return val.x; }
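The bit-cast trick used by CompareFloatsBitwiseEqual below can be illustrated without Eigen or absl; a minimal sketch for float, using memcpy as a portable bit cast so that two values compare equal exactly when their bit patterns match. The function name is made up for illustration.

#include <cstdint>
#include <cstring>

bool BitwiseEqual(float lhs, float rhs) {
  std::uint32_t ulhs, urhs;
  // memcpy is the portable way to reinterpret the raw bits of a float.
  std::memcpy(&ulhs, &lhs, sizeof(ulhs));
  std::memcpy(&urhs, &rhs, sizeof(urhs));
  return ulhs == urhs;
}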
|
||||
|
||||
// Helper function for comparing a floating point type, FloatT, bitwise equal
|
||||
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
|
||||
// -- on miscompare, a nice error message is given in the AssertionFailure.
|
||||
template <typename FloatT, typename UnsignedT>
|
||||
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
|
||||
absl::Span<const int64> multi_index) {
|
||||
// TODO(b/118627822): These are unsafe bit_casts because Eigen::Half is not
|
||||
// trivially copyable.
|
||||
auto ulhs = absl::bit_cast<UnsignedT>(lhs);
|
||||
auto urhs = absl::bit_cast<UnsignedT>(rhs);
|
||||
auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
|
||||
auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
|
||||
auto lhs_double = static_cast<double>(lhs);
|
||||
auto rhs_double = static_cast<double>(rhs);
|
||||
if (ulhs != urhs) {
|
||||
|
@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
||||
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
|
||||
EXPECT_EQ("{101}", pred_vec.ToString());
|
||||
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
|
||||
}
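A tiny sketch of the formatting this test now expects: PRED elements are still rendered as 0/1, but comma-separated like every other element type. The helper below is illustrative only, not the Literal::ToString implementation.

#include <string>
#include <vector>

std::string PredVecToString(const std::vector<bool>& preds) {
  std::string out = "{";
  for (size_t i = 0; i < preds.size(); ++i) {
    if (i > 0) out += ", ";
    out += preds[i] ? "1" : "0";
  }
  return out + "}";  // e.g. {true, false, true} -> "{1, 0, 1}"
}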
|
||||
|
||||
TEST_F(LiteralUtilTest, R2ToString) {
|
||||
|
@@ -184,8 +184,9 @@ StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
 
 int64 LocalShapedBufferTuple::size() const { return elements_.size(); }
 
-XrtAllocation::XrtAllocation(int64 handle, Shape shape)
-    : handle_(handle), shape_(shape) {}
+XrtAllocation::XrtAllocation(int64 handle, Shape shape,
+                             const string& session_target)
+    : handle_(handle), shape_(shape), session_target_(session_target) {}
 
 XrtAllocation::~XrtAllocation() {
   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
@@ -198,7 +199,7 @@ XrtAllocation::~XrtAllocation() {
     return;
   }
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target_);
   tensorflow::ClientSession::FeedType inputs;
   inputs.insert({allocation_handle, handle()});
   std::vector<tensorflow::Tensor> outputs;
@@ -210,7 +211,8 @@ XrtAllocation::~XrtAllocation() {
 }
 
 /* static */
-StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(const Literal& argument) {
+StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(
+    const Literal& argument, const string& session_target) {
   xrt::XLAAllocation alloc;
   alloc.set_device_ordinal(0);
   *alloc.mutable_value() = argument.ToProto();
@@ -221,14 +223,14 @@ StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(const Literal& argument) {
   auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string);
   TF_RETURN_IF_ERROR(root.status());
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target);
   tensorflow::ClientSession::FeedType inputs;
   inputs.insert({literal_string, alloc.SerializeAsString()});
   std::vector<tensorflow::Tensor> outputs;
   TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs));
 
   int64 handle = outputs[0].scalar<int64>()();
-  return new XrtAllocation(handle, argument.shape());
+  return new XrtAllocation(handle, argument.shape(), session_target);
 }
 
 const int64 XrtAllocation::handle() const { return handle_; }
@@ -242,7 +244,7 @@ StatusOr<Literal> XrtAllocation::ToLiteral() const {
   auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle);
   TF_RETURN_IF_ERROR(root.status());
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target_);
   tensorflow::ClientSession::FeedType inputs;
   inputs.insert({allocation_handle, handle()});
   std::vector<tensorflow::Tensor> outputs;
@@ -357,8 +359,11 @@ static StatusOr<Shape> GetReturnValueShape(const XlaComputation& computation) {
 }
 
 CompiledXrtComputation::CompiledXrtComputation(
-    const ProgramShape& program_shape, int64 handle)
-    : program_shape_(program_shape), handle_(handle) {}
+    const ProgramShape& program_shape, int64 handle,
+    const string& session_target)
+    : program_shape_(program_shape),
+      handle_(handle),
+      session_target_(session_target) {}
 
 CompiledXrtComputation::~CompiledXrtComputation() {
   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
@@ -371,7 +376,7 @@ CompiledXrtComputation::~CompiledXrtComputation() {
     return;
   }
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target_);
   tensorflow::ClientSession::FeedType inputs;
   inputs.insert({computation_handle, handle()});
   std::vector<tensorflow::Tensor> outputs;
@@ -407,7 +412,7 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
   e.set_release_input_handles(false);
   e.set_release_compilation_handle(false);
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target_);
   tensorflow::ClientSession::FeedType inputs;
   for (int i = 0; i < arguments.size(); ++i) {
     inputs.insert({arguments[i], argument_handles[i]->handle()});
@@ -418,7 +423,7 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
   TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs));
 
   int64 output = outputs[0].scalar<int64>()();
-  return new XrtAllocation(output, program_shape().result());
+  return new XrtAllocation(output, program_shape().result(), session_target_);
 }
 
 const ProgramShape& CompiledXrtComputation::program_shape() const {
@@ -451,7 +456,7 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
 }
 
 StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
-    const std::vector<Shape>& argument_shapes) {
+    const std::vector<Shape>& argument_shapes, const string& session_target) {
   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
   auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
   auto compile = tensorflow::ops::XRTCompile(root, program);
@@ -468,7 +473,7 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
   auto snapshot = computation().Snapshot().ValueOrDie();
   *c.mutable_hlo_snapshot() = *snapshot;
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target);
   tensorflow::ClientSession::FeedType inputs;
   inputs.insert({program, c.SerializeAsString()});
   std::vector<tensorflow::Tensor> outputs;
@@ -477,7 +482,7 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       computation().GetProgramShape());
   int64 handle = outputs[0].scalar<int64>()();
-  return new CompiledXrtComputation(program_shape, handle);
+  return new CompiledXrtComputation(program_shape, handle, session_target);
 }
 
 const XlaComputation& LocalComputation::computation() const {
@@ -929,7 +934,7 @@ StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
 }
 
 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
-    XrtAllocation* allocation) {
+    XrtAllocation* allocation, const string& session_target) {
   const Shape& tuple_shape = allocation->shape();
 
   if (!ShapeUtil::IsTuple(tuple_shape)) {
@@ -945,7 +950,7 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
   auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index);
   TF_RETURN_IF_ERROR(root.status());
 
-  tensorflow::ClientSession session(root, "local");
+  tensorflow::ClientSession session(root, session_target);
   tensorflow::ClientSession::FeedType inputs;
   std::vector<XrtAllocation*> results;
   for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
@@ -964,7 +969,8 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
     const int64 subtuple_handle = outputs[0].scalar<int64>()();
     const Shape& subtuple_shape =
         ShapeUtil::GetTupleElementShape(tuple_shape, i);
-    results.push_back(new XrtAllocation(subtuple_handle, subtuple_shape));
+    results.push_back(
+        new XrtAllocation(subtuple_handle, subtuple_shape, session_target));
   }
   return new XrtAllocationTuple(std::move(results));
 }
@@ -16,6 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
 #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
 
+#include <string>
 #include <vector>
 
 #include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
@@ -110,17 +113,22 @@ StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
 // graph, and an XLA shape to track the referent's shape.
 class XrtAllocation {
  public:
-  static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument);
+  // Accepts a `session_target` argument, used in constructing the
+  // `tensorflow::ClientSession` instance in which allocation and deallocation
+  // graphs are run.
+  static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument,
+                                              const string& session_target);
 
-  XrtAllocation(int64 handle, Shape shape);
+  XrtAllocation(int64 handle, Shape shape, const string& session_target);
   ~XrtAllocation();
   StatusOr<Literal> ToLiteral() const;
   const Shape& shape() const;
   const int64 handle() const;
 
  private:
-  int64 handle_;
-  Shape shape_;
+  const int64 handle_;
+  const Shape shape_;
+  const string session_target_;
 };
 
 // Result of a tuple destructuring operation on an XrtAllocation.
@@ -145,8 +153,12 @@ class XrtAllocationTuple {
 
 // Destructures a tuple-valued XrtAllocation into its constitutent elements
 // in XrtAllocationTuple form.
+//
+// Accepts a `session_target` argument, used in constructing the
+// `tensorflow::ClientSession` instance in which the sub-tupling graph is run,
+// and passed along in constructing each constituent XrtAllocation.
 StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
-    XrtAllocation* allocation);
+    XrtAllocation* allocation, const string& session_target);
 
 // Represents a compiled computation that can be executed given handles to
 // device-allocated literals. Specifically, wraps an XLA LocalExecutable.
@@ -165,7 +177,10 @@ class CompiledLocalComputation {
 // device-allocated literals. Specifically, wraps an XRT computation handle.
 class CompiledXrtComputation {
  public:
-  CompiledXrtComputation(const ProgramShape& program_shape, int64 handle);
+  // Accepts a `session_target` argument, used in constructing the
+  // `tensorflow::ClientSession` instance in which the execution graph is run.
+  CompiledXrtComputation(const ProgramShape& program_shape, int64 handle,
+                         const string& session_target);
   ~CompiledXrtComputation();
 
   StatusOr<XrtAllocation*> Execute(
@@ -175,8 +190,9 @@ class CompiledXrtComputation {
   int64 handle() const;
 
  private:
-  ProgramShape program_shape_;
-  int64 handle_;
+  const ProgramShape program_shape_;
+  const int64 handle_;
+  const string session_target_;
 };
 
 // Wraps a XlaComputation produced by a LocalComputationBuilder. The
@@ -191,8 +207,10 @@ class LocalComputation {
       const std::vector<Shape>& argument_shapes,
       const ExecutableBuildOptions* build_options);
 
+  // Accepts a `session_target` argument, used in constructing the
+  // `tensorflow::ClientSession` instance in which the compilation graph is run.
   StatusOr<CompiledXrtComputation*> CompileForXrt(
-      const std::vector<Shape>& argument_shapes);
+      const std::vector<Shape>& argument_shapes, const string& session_target);
 
   const XlaComputation& computation() const;
 
@@ -451,6 +451,10 @@ tensorflow::ImportNumpy();
 
 // Shape
 
+%typemap(out) const Shape& {
+  $result = numpy::PyShapeInfoFromXlaShape(*$1);
+}
+
 %typemap(out) StatusOr<Shape> {
   if ($1.ok()) {
     $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
@@ -980,6 +984,7 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::LocalShapedBuffer;
 %unignore xla::swig::LocalShapedBuffer::FromLiteral;
 %unignore xla::swig::LocalShapedBuffer::ToLiteral;
+%unignore xla::swig::LocalShapedBuffer::shape;
 %unignore xla::swig::LocalShapedBufferTuple;
 %unignore xla::swig::LocalShapedBufferTuple::Release;
 %unignore xla::swig::LocalShapedBufferTuple::size;
@@ -51,6 +51,10 @@ class BackendType(enum.Enum):
   XRT = 2
 
 
+BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target'))
+XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local')
+
+
 def OpMetadataToProto(pyobj):
   proto = xla_data_pb2.OpMetadata()
   for field in _OP_METADATA_FIELDS:
@@ -211,17 +215,17 @@ class LocalBuffer(object):
   def __init__(self, c_buffer, backend):
     self.c_buffer = c_buffer
     self._backend = backend
-    if backend == BackendType.XRT:
+    if backend.backend_type == BackendType.XRT:
       self._delete = c_api.DeleteXrtAllocation
     else:
       self._delete = c_api.DeleteLocalShapedBuffer
 
   @staticmethod
-  def from_pyval(pyval, backend=BackendType.XLA_LOCAL):
+  def from_pyval(pyval, backend=XLA_LOCAL_BACKEND):
     """Allocate and copy to XLA the given python value."""
     pyval = require_numpy_array_layout(pyval)
-    if backend == BackendType.XRT:
-      cbuf = c_api.XrtAllocation.FromLiteral(pyval)
+    if backend.backend_type == BackendType.XRT:
+      cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target)
     else:
       cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None)
     return LocalBuffer(cbuf, backend)
@@ -229,6 +233,9 @@ class LocalBuffer(object):
   def to_py(self):
     return self.c_buffer.ToLiteral()
 
+  def shape(self):
+    return _wrap_shape(self.c_buffer.shape())
+
   def delete(self):
     if self.c_buffer is not None:
       self._delete(self.c_buffer)
@@ -237,8 +244,9 @@ class LocalBuffer(object):
   def destructure(self):
     """Assuming a tuple buffer, unpack it into constituent tuple elements."""
     assert self.c_buffer is not None
-    if self._backend == BackendType.XRT:
-      result = c_api.DestructureXrtAllocationTuple(self.c_buffer)
+    if self._backend.backend_type == BackendType.XRT:
+      result = c_api.DestructureXrtAllocationTuple(self.c_buffer,
+                                                   self._backend.target)
     else:
       result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer)
     self.delete()
@@ -467,14 +475,14 @@ class LocalComputation(object):
   ComputationBuilder methods.
   """
 
-  def __init__(self, c_computation, is_compiled, backend=BackendType.XLA_LOCAL):
+  def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND):
     self._c_computation = c_computation
     self._backend = backend
     self._is_compiled = is_compiled
 
     # Ensure a reference to C-based destructor for use in __del__.
     if is_compiled:
-      if backend == BackendType.XRT:
+      if backend.backend_type == BackendType.XRT:
         assert isinstance(c_computation, c_api.CompiledXrtComputation)
         self._delete = c_api.DeleteCompiledXrtComputation
       else:
@@ -535,8 +543,8 @@ class LocalComputation(object):
 
     compile_options = compile_options or CompileOptions()
     compile_options.result_shape = result_shape
-    if self._backend == BackendType.XRT:
-      c = self.computation.CompileForXrt(argument_shapes)
+    if self._backend.backend_type == BackendType.XRT:
+      c = self.computation.CompileForXrt(argument_shapes, self._backend.target)
     else:
       c = self.computation.Compile(argument_shapes, compile_options)
     return LocalComputation(c, is_compiled=True, backend=self._backend)
@@ -590,7 +598,7 @@ class ComputationBuilder(object):
     self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
     self._parameter_numbering = itertools.count()
 
-  def Build(self, root=None, backend=BackendType.XLA_LOCAL):
+  def Build(self, root=None, backend=XLA_LOCAL_BACKEND):
     if root is not None:
       return LocalComputation(
           self._client.BuildWithRoot(root), is_compiled=False, backend=backend)
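Putting the pieces above together, here is a minimal usage sketch of the new backend plumbing. It only uses names introduced in this diff (BackendSpec, BackendType, LocalBuffer.from_pyval, shape, to_py, delete); the import path and the assumption that the build has the XRT ops available are mine, not part of the change.

# Minimal sketch, not part of the diff. 'local' is the in-process session
# target that was previously hard-coded in the C++ layer; a remote tf.Session
# target string could be passed instead.
import numpy as np
from tensorflow.compiler.xla.python import xla_client  # assumed import path

xrt_backend = xla_client.BackendSpec(xla_client.BackendType.XRT, 'local')

pyval = np.array([[1., 2.]], dtype=np.float32)
buf = xla_client.LocalBuffer.from_pyval(pyval, backend=xrt_backend)

print(buf.shape().dimensions())  # (1, 2), via the new LocalBuffer.shape()
print(buf.to_py())               # copies the literal back to the host
buf.delete()

The same BackendSpec can be handed to ComputationBuilder.Build and LocalComputation.Compile, which forward backend.target down to CompileForXrt and the XRT execution path shown earlier.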
@@ -439,6 +439,13 @@ class LocalBufferTest(LocalComputationTest):
     np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
     np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
 
+  def testShape(self):
+    pyval = np.array([[1., 2.]], np.float32)
+    local_buffer = xla_client.LocalBuffer.from_pyval(pyval)
+    xla_shape = local_buffer.shape()
+    self.assertEqual(xla_shape.dimensions(), (1, 2,))
+    self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
+
 
 class SingleOpTest(LocalComputationTest):
   """Tests for single ops.
@ -323,7 +323,6 @@ cc_library(
|
||||
":hlo_casting_utils",
|
||||
":hlo_module_config",
|
||||
":hlo_proto",
|
||||
":hlo_reachability",
|
||||
":name_uniquer",
|
||||
"//tensorflow/compiler/xla:array",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -402,6 +401,7 @@ cc_library(
|
||||
srcs = ["hlo_reachability.cc"],
|
||||
hdrs = ["hlo_reachability.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
@ -1103,6 +1103,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_proto",
|
||||
":hlo_reachability",
|
||||
":hlo_value",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -1362,6 +1363,7 @@ cc_library(
|
||||
":fusion_queue",
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_reachability",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@ -1387,6 +1389,7 @@ cc_library(
|
||||
srcs = ["multi_output_fusion.cc"],
|
||||
hdrs = ["multi_output_fusion.h"],
|
||||
deps = [
|
||||
":hlo_reachability",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -3241,6 +3244,7 @@ cc_library(
|
||||
":hlo_profile_printer_data",
|
||||
":human_readable_profile_builder",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -3365,6 +3369,7 @@ cc_library(
|
||||
":bfloat16_normalization",
|
||||
":defuser",
|
||||
":hlo",
|
||||
":hlo_memory_scheduler",
|
||||
":hlo_pass",
|
||||
":hlo_pass_pipeline",
|
||||
":implicit_broadcast_remover",
|
||||
@ -3448,6 +3453,7 @@ tf_cc_test(
|
||||
":hlo_casting_utils",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
|
@@ -306,6 +306,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   // Tries to use a kDot in place of the given convolution.
   StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
 
+  // Tries to simplify a slice(pad(...)) where the result of the slice is a
+  // scalar.
+  StatusOr<bool> TrySimplifySliceOfPad(HloInstruction* slice);
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -1822,6 +1826,62 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
   return Status::OK();
 }
 
+StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifySliceOfPad(
+    HloInstruction* slice) {
+  // Only try to do this for effective scalars. We could do the same for slicing
+  // out larger pieces of padding (replacing with a broadcast of the padding
+  // value), but this is probably not worth it.
+  if (!ShapeUtil::IsEffectiveScalar(slice->shape()) ||
+      slice->operand(0)->opcode() != HloOpcode::kPad) {
+    return false;
+  }
+
+  VLOG(10) << "Trying to simplify scalar slice of pad";
+  // Check there's no internal padding. Again, we could handle that too, since
+  // everything is statically known, but it's not worth it.
+  auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
+  auto padding_config = pad->padding_config();
+  int64 rank = padding_config.dimensions_size();
+  if (HasInteriorPadding(padding_config)) {
+    VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
+    return false;
+  }
+
+  // Check whether the scalar we're slicing out falls into the padding.
+  bool in_padding = [&]() {
+    for (int64 i = 0; i < rank; ++i) {
+      int64 start = slice->slice_starts(i);
+      int64 low = padding_config.dimensions(i).edge_padding_low();
+      int64 data = pad->operand(0)->shape().dimensions(i);
+      if (start >= low && start < low + data) {
+        return false;
+      }
+    }
+    return true;
+  }();
+
+  if (in_padding) {
+    VLOG(10) << "Folding scalar slice of pad into padding value";
+    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
+        slice, HloInstruction::CreateReshape(slice->shape(),
+                                             pad->mutable_padding_value())));
+    return true;
+  } else {
+    // We already know the output of the slice is scalar. If the padded
+    // value is scalar, and it's not in the padding, then it's exactly the
+    // output value.
+    bool replaced =
+        ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
+    if (replaced) {
+      VLOG(10) << "Folding scalar slice of pad into padded value";
+    } else {
+      VLOG(10) << "Not folding scalar slice of pad into padded value as they "
+                  "have different shapes.";
+    }
+    return replaced;
+  }
+}
+
 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
   // Delete no-op slices, i.e. where shape = operand shape.
   if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
@@ -1846,6 +1906,12 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
         slice->shape(), operand_slice->mutable_operand(0),
         new_slice_starts, new_slice_limits, slice->slice_strides()));
   }
 
+  TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice));
+  if (replaced) {
+    return Status::OK();
+  }
+
   return Status::OK();
 }
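Restating the in_padding lambda above: with s_i the slice start, low_i the low edge padding, and d_i the pad operand's extent in dimension i, the check it computes is

\text{in\_padding} \iff \forall\, i \in [0, \text{rank}) :\; s_i < \mathrm{low}_i \ \lor\ s_i \ge \mathrm{low}_i + d_i,

i.e. the start index must miss the data window in every dimension. That is stricter than "the sliced element is a padding element" (which would only require one dimension to miss the window), so the fold into the padding value only fires when every coordinate lands in the padding; the SliceOfPadMidNonScalar test below exercises a start that is inside the data window in one dimension only, and the pass leaves that slice unchanged.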
@@ -3163,6 +3163,92 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
 }
 
+TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
+  const char* hlo_string = R"(
+    HloModule module
+
+    ENTRY test {
+      param = f32[3,4] parameter(0)
+      constant = f32[] constant(0.0)
+      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
+      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 bitcasting_callback());
+  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Reshape(op::Constant()));
+}
+
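To make the arithmetic in this test concrete (this is just a reading of the HLO above, not part of the change): the padding config 3_2x1_5 gives low edge padding (3, 1), the operand is f32[3,4] so d = (3, 4), and the slice starts at (2, 0). Then

s = (2, 0),\quad \mathrm{low} = (3, 1),\quad d = (3, 4):\qquad 2 < 3 \ \land\ 0 < 1,

so every start coordinate lies in the low padding, the in_padding condition holds, and the simplifier rewrites the slice into a reshape of the padding constant, which is exactly what EXPECT_THAT(root, op::Reshape(op::Constant())) asserts.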
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
param = f32[3,4] parameter(0)
|
||||
constant = f32[] constant(0.0)
|
||||
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
|
||||
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
bitcasting_callback());
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Reshape(op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
param = f32[3,4] parameter(0)
|
||||
constant = f32[] constant(0.0)
|
||||
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
|
||||
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
bitcasting_callback());
|
||||
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
param = f32[1,1] parameter(0)
|
||||
constant = f32[] constant(0.0)
|
||||
pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
|
||||
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
bitcasting_callback());
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Parameter());
|
||||
}
|
||||
|
||||
struct PadReduceWindowEffectiveBroadcastCase {
|
||||
std::vector<int64> input_spatials;
|
||||
std::vector<int64> symmetric_pad_spatials;
|
||||
|
@ -151,15 +151,10 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
|
||||
|
||||
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
// Do not fold BF16 conversions for instructions related to tuples, entry and
|
||||
// exit of a computation, fusion, convert, and control flow.
|
||||
// exit of a computation, fusion, convert, side-effecting instructions and
|
||||
// control flow.
|
||||
if (hlo->opcode() == HloOpcode::kTuple || //
|
||||
hlo->opcode() == HloOpcode::kGetTupleElement || //
|
||||
hlo->opcode() == HloOpcode::kInfeed || //
|
||||
hlo->opcode() == HloOpcode::kOutfeed || //
|
||||
hlo->opcode() == HloOpcode::kSend || //
|
||||
hlo->opcode() == HloOpcode::kSendDone || //
|
||||
hlo->opcode() == HloOpcode::kRecv || //
|
||||
hlo->opcode() == HloOpcode::kRecvDone || //
|
||||
hlo->opcode() == HloOpcode::kConstant || //
|
||||
hlo->opcode() == HloOpcode::kParameter || //
|
||||
hlo->opcode() == HloOpcode::kFusion || //
|
||||
@ -167,7 +162,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kWhile || //
|
||||
hlo->opcode() == HloOpcode::kConditional) {
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
hlo->HasSideEffectNoRecurse()) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (hlo == computation_->root_instruction() &&
|
||||
@ -182,6 +178,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
|
||||
Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
|
||||
HloInstruction* crs) {
|
||||
if (crs->IsCrossModuleAllReduce()) {
|
||||
// Cross-module all-reduce has side effect.
|
||||
return Status::OK();
|
||||
}
|
||||
// First use DefaultAction() to handle the operands. It can't handle
|
||||
// tuple-shaped output.
|
||||
TF_RETURN_IF_ERROR(DefaultAction(crs));
|
||||
|
@ -346,11 +346,9 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
|
||||
|
||||
Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
// Do not change instructions related to entry and exit of a computation,
|
||||
// tuples, fusion, convert, and control flow.
|
||||
// tuples, fusion, convert, side-effecting instructions, and control flow.
|
||||
if (hlo->opcode() == HloOpcode::kTuple || //
|
||||
hlo->opcode() == HloOpcode::kGetTupleElement || //
|
||||
hlo->opcode() == HloOpcode::kInfeed || //
|
||||
hlo->opcode() == HloOpcode::kOutfeed || //
|
||||
hlo->opcode() == HloOpcode::kConstant || //
|
||||
hlo->opcode() == HloOpcode::kParameter || //
|
||||
hlo->opcode() == HloOpcode::kFusion || //
|
||||
@ -358,7 +356,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kWhile || //
|
||||
hlo->opcode() == HloOpcode::kConditional) {
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
hlo->HasSideEffectNoRecurse()) {
|
||||
return Status::OK();
|
||||
}
|
||||
// TODO(b/112040122): Correctly normalize variadic reduce.
|
||||
|
@ -236,6 +236,10 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
||||
// the end of the BFloat16Propagation pass.
|
||||
continue;
|
||||
}
|
||||
if (use.instruction->HasSideEffectNoRecurse()) {
|
||||
// Keep side-effecting instruction's operands unchanged.
|
||||
return false;
|
||||
}
|
||||
// Any visited user that can accept BF16 has already been updated if
|
||||
// necessary, e.g., the output has been changed to BF16 if it propagates
|
||||
// precision, or a called computation's parameters have been changed to
|
||||
@ -329,22 +333,6 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
|
||||
return;
|
||||
}
|
||||
|
||||
// Do not change precision for instructions related to entry and exit of a
|
||||
// computation, and control flow, because this pass might break the interfaces
|
||||
// or assumptions for them.
|
||||
if (hlo->opcode() == HloOpcode::kInfeed || //
|
||||
hlo->opcode() == HloOpcode::kOutfeed || //
|
||||
hlo->opcode() == HloOpcode::kSend || //
|
||||
hlo->opcode() == HloOpcode::kSendDone || //
|
||||
hlo->opcode() == HloOpcode::kRecv || //
|
||||
hlo->opcode() == HloOpcode::kRecvDone || //
|
||||
hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
(hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prevent root instructions from having their output modified by recording
|
||||
// all F32 output values as needing to stay as F32.
|
||||
CHECK(hlo->parent() != nullptr);
|
||||
@ -366,6 +354,17 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
|
||||
return;
|
||||
}
|
||||
|
||||
// Do not change precision for instructions related to entry and exit of a
|
||||
// computation, side-effecting instructions, and control flow, because this
|
||||
// pass might break the interfaces or assumptions for them.
|
||||
if (hlo->opcode() == HloOpcode::kCustomCall || //
|
||||
hlo->opcode() == HloOpcode::kCall || //
|
||||
hlo->opcode() == HloOpcode::kConditional || //
|
||||
hlo->HasSideEffectNoRecurse() || //
|
||||
(hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ContainsKey(consider_using_bfloat16_, hlo)) {
|
||||
return;
|
||||
}
|
||||
|
@ -136,6 +136,40 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
|
||||
EXPECT_FALSE(OutputsBF16(c));
|
||||
}
|
||||
|
||||
// Tests that side-effecting all-reduce should not be changed.
|
||||
TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
|
||||
HloInstruction* a =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
|
||||
HloInstruction* b =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
|
||||
auto rb = HloComputation::Builder(TestName());
|
||||
rb.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd,
|
||||
rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
|
||||
rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
|
||||
auto reduction = module->AddEmbeddedComputation(rb.Build());
|
||||
HloInstruction* all_reduce =
|
||||
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
|
||||
ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
|
||||
/*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1));
|
||||
HloInstruction* gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
|
||||
HloInstruction* gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
|
||||
HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
|
||||
HloInstruction* root = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
|
||||
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_FALSE(PropagatePrecision(module.get()));
|
||||
EXPECT_EQ(computation->root_instruction(), root);
|
||||
}
|
||||
|
||||
// Tests that if a constant is converted to BF16 then its literal must also be
|
||||
// converted.
|
||||
TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
|
||||
|
@ -502,8 +502,8 @@ Status CreateHloProfilingArtifacts(
|
||||
|
||||
HloCostAnalysis cost_analysis(shape_size_bytes);
|
||||
TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
|
||||
*hlo_profile_printer_data =
|
||||
CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis);
|
||||
*hlo_profile_printer_data = CreateHloProfilePrinterData(
|
||||
**hlo_profile_index_map, cost_analysis, entry_computation.name());
|
||||
*computation_to_profile_idx =
|
||||
(*hlo_profile_index_map)->computation_to_profile_idx();
|
||||
|
||||
|
@ -1546,10 +1546,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
||||
LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
|
||||
}
|
||||
|
||||
// Return whether the given shape is a matrix with no padding.
|
||||
static bool IsRank2WithNoPadding(const Shape& shape) {
|
||||
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
|
||||
}
|
||||
// Return whether the given shape is rank 2.
|
||||
static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; }
|
||||
|
||||
// In a gemm operation where output = lhs * rhs, check whether the given shapes
|
||||
// are valid for the operation.
|
||||
@ -1565,8 +1563,7 @@ static bool AreValidGemmShapes(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
|
||||
IsRank2WithNoPadding(output_shape))) {
|
||||
if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -2206,16 +2206,16 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
|
||||
|
||||
// Delegate to common implementation of fused in-place dynamic-update-slice.
|
||||
auto operands = GetIrArraysForOperandsOf(fusion);
|
||||
return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
|
||||
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_);
|
||||
fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
|
||||
&elemental_emitter, &b_);
|
||||
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
VLOG(3) << "HandleFusion kLoop";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
auto operands = GetIrArraysForOperandsOf(fusion);
|
||||
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||
|
||||
return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
|
||||
@ -2415,14 +2415,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
|
||||
*failure_reason = "operand has mismatching layouts";
|
||||
return false;
|
||||
}
|
||||
if (LayoutUtil::IsPadded(op->shape())) {
|
||||
*failure_reason = "operand has padded layout";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
|
||||
|
||||
// We split the dimensions into three categories: the dimension over which we
|
||||
// are concatenating (concat_dim), the dimensions that are minor to it
|
||||
// (inner_dims) and the dimensions that are major to it (outer_dims).
|
||||
|
@ -59,6 +59,9 @@ namespace cpu {
|
||||
class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
public IrBuilderMixin<IrEmitter> {
|
||||
public:
|
||||
using GeneratorForOperandIrArrays =
|
||||
std::function<std::vector<llvm_ir::IrArray>()>;
|
||||
|
||||
// Create a new LLVM IR emitter.
|
||||
//
|
||||
// hlo_module: the HLO module we are emitting IR for.
|
||||
@ -208,6 +211,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
|
||||
const HloInstruction* hlo);
|
||||
|
||||
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
|
||||
HloInstruction* unnested_hlo) {
|
||||
return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
|
||||
}
|
||||
|
||||
// Augments IrArray with aliasing information.
|
||||
void AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
llvm_ir::IrArray* array) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
|
||||
#include "tensorflow/compiler/xla/service/defuser.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||
#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
|
||||
|
||||
namespace xla {
|
||||
@ -45,6 +46,7 @@ class ControlDepRemover : public HloModulePass {
|
||||
|
||||
Despecializer::Despecializer() : pipeline_("despecializer") {
|
||||
// TODO(b/70588125): Also deal with window reversal in a fast way.
|
||||
pipeline_.AddPass<HloDescheduler>();
|
||||
pipeline_.AddPass<ControlDepRemover>();
|
||||
pipeline_.AddPass<Defuser>();
|
||||
pipeline_.AddPass<ImplicitBroadcastRemover>();
|
||||
|
@ -484,7 +484,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
|
||||
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
|
||||
}
|
||||
|
||||
// Extracted from //learning/brain/google/xla/benchmarks/resnet.py
|
||||
// Extracted from Resnet-50.
|
||||
//
|
||||
// For simplicity, we focus on the column dimension and ignore other dimensions.
|
||||
// We use [?] to represent the shape instead of the content.
|
||||
|
@ -126,9 +126,9 @@ Status RunCudnnConvImpl(CudnnConvParams params,
|
||||
int64 feature_group_count = params.feature_group_count;
|
||||
AlgorithmConfig algorithm = params.algorithm;
|
||||
|
||||
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
|
||||
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm()->algo_id();
|
||||
VLOG(3) << "tensor_ops_enabled: "
|
||||
<< algorithm.algorithm().tensor_ops_enabled();
|
||||
<< algorithm.algorithm()->tensor_ops_enabled();
|
||||
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
|
||||
VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
|
||||
VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
|
||||
@ -302,8 +302,8 @@ Status RunCudnnConvImpl(CudnnConvParams params,
|
||||
if (!stream->ok()) {
|
||||
return InternalError(
|
||||
"Unable to launch convolution with type %s and algorithm (%d, %d)",
|
||||
CudnnConvKindToString(kind), algorithm.algorithm().algo_id(),
|
||||
algorithm.algorithm_no_scratch().algo_id());
|
||||
CudnnConvKindToString(kind), algorithm.algorithm()->algo_id(),
|
||||
algorithm.algorithm_no_scratch()->algo_id());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <functional>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
@@ -51,7 +52,8 @@ struct MatrixDescriptor {
 // rhs_matrix, and stores the result to output_matrix.
 template <typename Element>
 bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
-            MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
+            MatrixDescriptor output_matrix, double alpha, double beta,
+            se::Stream* stream) {
   DCHECK(!output_matrix.transpose);
 
   const int64 batch_size = lhs_matrix.batch_size;
@@ -73,7 +75,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
            lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
-           /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+           /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
            &output_data, /*leading dim of output=*/output_matrix.num_rows)
        .ok();
 }
@@ -88,7 +90,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
            /*alpha=*/alpha, lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
-           /*beta=*/0.0, &output_data,
+           /*beta=*/beta, &output_data,
            /*leading dim of output=*/output_matrix.num_rows, output_stride,
            batch_size)
        .ok();
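For reference, the beta value now threaded through these wrappers has the standard BLAS GEMM meaning; this is the contract of the ThenBlasGemm calls above, not new behavior added here:

C \leftarrow \alpha\,(A \times B) + \beta\,C

where C is the output matrix. Previously the thunk always passed beta = 0, so the output buffer was simply overwritten. Forwarding beta is what allows the new bias-add output fusion (see the add(dot, ...) fusion support and the DotOutputFusionBiasAdd test later in this diff) to be emitted as a single GEMM that accumulates into an output buffer pre-populated with the bias, presumably with beta = 1.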
@ -112,6 +114,7 @@ template <typename Element>
|
||||
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, double alpha,
|
||||
double beta,
|
||||
se::blas::ComputationType computation_type,
|
||||
se::blas::AlgorithmType algorithm, se::Stream* stream,
|
||||
se::blas::ProfileResult* output_profile_result) {
|
||||
@ -138,7 +141,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
/*alpha=*/static_cast<Element>(alpha), lhs_data,
|
||||
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
|
||||
/*leading dim of RHS=*/rhs_matrix.num_rows,
|
||||
/*beta=*/static_cast<Element>(0.0f), &output_data,
|
||||
/*beta=*/static_cast<Element>(beta), &output_data,
|
||||
/*leading dim of output=*/output_matrix.num_rows, computation_type,
|
||||
algorithm, output_profile_result)
|
||||
.ok();
|
||||
@ -153,7 +156,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
|
||||
template <typename Element>
|
||||
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
|
||||
MatrixDescriptor output_matrix, double alpha,
|
||||
MatrixDescriptor output_matrix, double alpha, double beta,
|
||||
se::blas::ComputationType computation_type, se::Stream* stream) {
|
||||
std::vector<se::blas::AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
|
||||
@ -166,7 +169,7 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
|
||||
// non-null ProfileResult, DoGemmWithAlgorithm should always return true,
|
||||
// and the actual success-ness is returned in ProfileResult::is_valid.
|
||||
CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
|
||||
alpha, computation_type, algorithm,
|
||||
alpha, beta, computation_type, algorithm,
|
||||
stream, &profile_result));
|
||||
|
||||
if (profile_result.is_valid()) {
|
||||
@ -263,8 +266,9 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
|
||||
}
|
||||
CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
|
||||
CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
|
||||
CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(),
|
||||
HloOpcode::kMultiply);
|
||||
CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd ||
|
||||
hlo_instruction.fused_expression_root()->opcode() ==
|
||||
HloOpcode::kMultiply);
|
||||
// Try to find the dot inside the output fusion node.
|
||||
const HloInstruction* dot =
|
||||
hlo_instruction.fused_expression_root()->operand(0);
|
||||
@ -282,8 +286,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
|
||||
const BufferAllocation::Slice& rhs_buffer,
|
||||
const BufferAllocation::Slice& output_buffer,
|
||||
const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
const Shape& output_shape, double alpha,
|
||||
const HloInstruction* hlo_instruction)
|
||||
const Shape& output_shape, double alpha, double beta,
|
||||
const HloInstruction* hlo_instruction,
|
||||
bool implements_whole_instruction)
|
||||
: Thunk(Kind::kGemm, hlo_instruction),
|
||||
lhs_buffer_(lhs_buffer),
|
||||
rhs_buffer_(rhs_buffer),
|
||||
@ -291,7 +296,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
|
||||
lhs_shape_(lhs_shape),
|
||||
rhs_shape_(rhs_shape),
|
||||
output_shape_(output_shape),
|
||||
alpha_(alpha) {}
|
||||
alpha_(alpha),
|
||||
beta_(beta),
|
||||
implements_whole_instruction_(implements_whole_instruction) {}
|
||||
|
||||
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
@ -386,7 +393,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
// TODO(b/112111608): Implement auto tune for batched gemm.
|
||||
if (batch_size != 1) {
|
||||
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
|
||||
alpha_, stream);
|
||||
alpha_, beta_, stream);
|
||||
}
|
||||
|
||||
auto thunk_name = [&] {
|
||||
@@ -398,9 +405,27 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   auto autotune_it = autotune_results_.find(device_name);
   if (autotune_it == autotune_results_.end()) {
     VLOG(3) << "Starting autotune of GemmThunk " << thunk_name();
-    StatusOr<se::blas::AlgorithmType> best_algorithm =
-        GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
-                                        alpha_, computation_type, stream);
+
+    // If the output buffer already contains a bias then autotune into a
+    // scratch buffer. This avoids overwriting the bias buffer. The scratch
+    // buffer may contain arbitrary garbage values.
+    se::DeviceMemoryBase scratch_data = output_data;
+    std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
+    if (beta_ != 0.0) {
+      auto temp_status = stream->AllocateTemporaryArray<char>(
+          ShapeUtil::ByteSizeOf(output_shape_));
+      if (!temp_status.ok()) {
+        return false;
+      }
+      scratch_mem = std::move(temp_status).ValueOrDie();
+      scratch_data = scratch_mem->device_memory();
+    }
+    const MatrixDescriptor scratch_descriptor(
+        scratch_data, false, output_num_cols, output_num_rows, batch_size);
+
+    StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
+        element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_,
+                      beta_, computation_type, stream);
     autotune_it =
         autotune_results_.insert({device_name, best_algorithm}).first;
 
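One subtlety worth spelling out about the scratch-buffer logic above: autotuning actually executes the GEMM once per candidate algorithm, so with beta != 0 every trial would accumulate into, and thereby clobber, the bias that is already sitting in the real output buffer. Redirecting the trials at a temporary buffer keeps that buffer intact; the trials then read whatever garbage the scratch buffer holds as their beta input, which is acceptable here because autotuning only uses the timings, not the numerical result.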
@ -421,18 +446,19 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
VLOG(2) << "Using algorithm " << algorithm
|
||||
<< " chosen by autotuning on GemmThunk " << thunk_name();
|
||||
return GetGemmWithAlgorithmFn(element_type)(
|
||||
lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type,
|
||||
algorithm, stream,
|
||||
lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
|
||||
computation_type, algorithm, stream,
|
||||
/*output_profile_result=*/nullptr);
|
||||
}
|
||||
|
||||
// Autotune will fail when CUDA 8 and GPU sm_50 or older are used.
|
||||
// Use the older Gemm API in this case.
|
||||
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
|
||||
alpha_, stream);
|
||||
alpha_, beta_, stream);
|
||||
};
|
||||
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
auto op_profiler = profiler->MakeScopedInstructionProfiler(
|
||||
implements_whole_instruction_ ? hlo_instruction() : nullptr);
|
||||
bool launch_ok;
|
||||
if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
|
||||
launch_ok = launch(lhs_descriptor, rhs_descriptor,
|
||||
|
@ -41,8 +41,9 @@ class GemmThunk : public Thunk {
|
||||
const BufferAllocation::Slice& rhs_buffer,
|
||||
const BufferAllocation::Slice& output_buffer,
|
||||
const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
const Shape& output_shape, double alpha,
|
||||
const HloInstruction* hlo_instruction);
|
||||
const Shape& output_shape, double alpha, double beta,
|
||||
const HloInstruction* hlo_instruction,
|
||||
bool implements_whole_instruction);
|
||||
|
||||
GemmThunk(const GemmThunk&) = delete;
|
||||
GemmThunk& operator=(const GemmThunk&) = delete;
|
||||
@ -70,6 +71,9 @@ class GemmThunk : public Thunk {
|
||||
const Shape output_shape_;
|
||||
|
||||
const double alpha_;
|
||||
const double beta_;
|
||||
|
||||
const bool implements_whole_instruction_;
|
||||
|
||||
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
|
||||
// results. The map's value is the best algorithm we've found for this thunk
|
||||
|
@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering(
|
||||
for (auto* computation : module->computations()) {
|
||||
if (computation != module->entry_computation() &&
|
||||
!computation->IsFusionComputation()) {
|
||||
predecessors_.emplace(computation, computation->ComputeReachability());
|
||||
predecessors_.emplace(computation,
|
||||
HloReachabilityMap::Build(computation));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -179,6 +179,10 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
|
||||
return true;
|
||||
}
|
||||
} else if (consumer->operand_count() == 2 &&
|
||||
consumer->opcode() == HloOpcode::kAdd) {
|
||||
// Fuse a bias add into the output of the dot.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -331,6 +331,33 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
|
||||
op::Broadcast(op::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
|
||||
auto module = ParseHloString(R"(
|
||||
HloModule test_module
|
||||
ENTRY OutputFusion {
|
||||
alpha = f32[] constant(3)
|
||||
broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
|
||||
p0 = f32[4,3]{1,0} parameter(0)
|
||||
p1 = f32[4,3]{1,0} parameter(1)
|
||||
p2 = f32[4,4]{1,0} parameter(2)
|
||||
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
|
||||
dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
ROOT add = f32[4,4] add(dot, p2)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Fusion());
|
||||
EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
|
||||
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Add(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
|
||||
op::Parameter()));
|
||||
}
|
||||
|
||||
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
|
||||
// duplicated and fused into both reduces.
|
||||
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
|
||||
|
@ -38,10 +38,9 @@ namespace gpu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Return whether the given shape is a matrix with no padding.
|
||||
bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
|
||||
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
|
||||
!LayoutUtil::IsPadded(shape);
|
||||
// Return whether the given shape is rank 2 excluding the batch dimensions.
|
||||
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
|
||||
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2;
|
||||
}
|
||||
|
||||
// In a gemm operation where output = lhs * rhs, check whether the given shapes
|
||||
@ -56,10 +55,9 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
bool type_is_allowed =
|
||||
(output_primitive_type == F16 || output_primitive_type == F32 ||
|
||||
output_primitive_type == F64 || output_primitive_type == C64);
|
||||
return type_is_allowed &&
|
||||
IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
|
||||
IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
|
||||
IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
|
||||
return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
|
||||
IsRank2(rhs_shape, batch_dimensions_size) &&
|
||||
IsRank2(output_shape, batch_dimensions_size) &&
|
||||
!ShapeUtil::IsZeroElementArray(lhs_shape) &&
|
||||
!ShapeUtil::IsZeroElementArray(rhs_shape);
|
||||
}
|
||||
@ -93,7 +91,8 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
|
||||
|
||||
if (hlo.opcode() == HloOpcode::kFusion &&
|
||||
hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
|
||||
hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) {
|
||||
(hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply ||
|
||||
hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) {
|
||||
// Try to find the dot inside the output fusion node.
|
||||
const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
|
||||
if (dot->opcode() != HloOpcode::kDot) {
|
||||
|
@ -697,15 +697,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
|
||||
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
// kFusion for library calls should be handled by
|
||||
// IrEmitterUnnested::HandleFusion.
|
||||
CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind());
|
||||
|
||||
std::vector<llvm_ir::IrArray> parameter_arrays;
|
||||
for (HloInstruction* operand : fusion->operands()) {
|
||||
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
|
||||
}
|
||||
CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
|
||||
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
|
||||
&elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||
|
||||
return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
|
||||
|
Some files were not shown because too many files have changed in this diff.