Merge branch 'master' into tflite

commit 31c345a9fb
Author: BY Shen
Date: 2018-11-06 12:35:23 +08:00 (committed by GitHub)
1995 changed files with 25699 additions and 14084 deletions

.gitignore (vendored): 8 lines changed

@ -24,10 +24,10 @@ Pods
Podfile.lock
*.pbxproj
*.xcworkspacedata
/tensorflow/contrib/lite/tools/make/downloads/**
/tensorflow/contrib/lite/gen/**
/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/gen/**
/tensorflow/lite/examples/ios/simple/data/*.txt
/tensorflow/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt

BUILD: 2 lines changed

@ -2,5 +2,7 @@ exports_files(
[
"LICENSE",
"ACKNOWLEDGEMENTS",
"configure",
"configure.py",
],
)

View File

@ -258,8 +258,8 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
* Update `tf.keras` to the Keras 2.1.6 API.
* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite)
for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md)
* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite)
for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md)
has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
included in the standard `pip` installation.
* Improved data-loading and text processing with:
@ -562,7 +562,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
## Major Features And Improvements
* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
preview version is now available.
* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/lite)
dev preview is now available.
* CUDA 9.0 and cuDNN 7 support.
* Accelerated Linear Algebra (XLA):
@ -909,7 +909,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/
* Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss.
* `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less or equal to 1 unit of weight.
* Adds time series models to contrib. See contrib/timeseries/README.md for details.
* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs
* Adds FULLY_CONNECTED Op to tensorflow/lite/schema.fbs
## Known Issues
* Tensorflow_gpu compilation fails with Bazel 0.5.3.

View File

@ -1418,11 +1418,16 @@ def set_mpi_home(environ_cp):
def valid_mpi_path(mpi_home):
exists = (
os.path.exists(os.path.join(mpi_home, 'include')) and
os.path.exists(os.path.join(mpi_home, 'lib')))
(os.path.exists(os.path.join(mpi_home, 'lib')) or
os.path.exists(os.path.join(mpi_home, 'lib64')) or
os.path.exists(os.path.join(mpi_home, 'lib32'))))
if not exists:
print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
(os.path.join(mpi_home, 'include'),
os.path.exists(os.path.join(mpi_home, 'lib'))))
print(
'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
% (os.path.join(mpi_home, 'include'),
os.path.exists(os.path.join(mpi_home, 'lib')),
os.path.exists(os.path.join(mpi_home, 'lib64')),
os.path.exists(os.path.join(mpi_home, 'lib32'))))
return exists
_ = prompt_loop_or_load_from_env(
@ -1463,8 +1468,17 @@ def set_other_mpi_vars(environ_cp):
if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
symlink_force(
os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
else:
raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
raise ValueError(
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
(mpi_home, mpi_home, mpi_home))
def set_system_libs_flag(environ_cp):
@ -1681,4 +1695,3 @@ def main():
if __name__ == '__main__':
main()

View File

@ -55,4 +55,10 @@ except NameError:
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass
# Similarly for compiler. Do it separately to make sure we do this even if the
# others don't exist.
try:
del compiler
except NameError:
pass
# pylint: enable=undefined-variable

View File

@ -63,4 +63,10 @@ except NameError:
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass
# Similarly for compiler. Do it separately to make sure we do this even if the
# others don't exist.
try:
del compiler
except NameError:
pass
# pylint: enable=undefined-variable

View File

@ -199,6 +199,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["c_api_test.cc"],
data = [
":test_op1.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
kernels = [":test_op_kernel"],
@ -283,8 +284,8 @@ tf_cc_test(
)
tf_custom_op_library(
name = "test_op.so",
srcs = ["test_op.cc"],
name = "test_op1.so",
srcs = ["test_op1.cc"],
)
tf_kernel_library(

View File

@ -8775,3 +8775,28 @@ void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
/* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
TF_Status* status) {
const tensorflow::OpDef* op_def = nullptr;
status->status =
tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
if (!status->status.ok()) return nullptr;
if (input_index >= op_def->input_arg_size() || input_index < 0) {
status->status = tensorflow::errors::InvalidArgument(
input_index, " out of range for ", op_name);
return nullptr;
}
const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
if (input_arg.number_attr().empty()) {
status->status = tensorflow::errors::NotFound(
op_name, " does not have number_attr() defined.");
return nullptr;
}
// The returned string is owned by OpRegistry, so liveness is not a concern.
return input_arg.number_attr().c_str();
}

View File

@ -202,6 +202,13 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder,
TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
TF_AttrBuilder* builder, const char* device_type, TF_Status* status);
// For argument number input_index, fetch the corresponding number_attr that
// needs to be updated with the argument length of the input list.
// Returns nullptr if there is any problem like op_name is not found, or the
// argument does not support this attribute type.
TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
const char* op_name, int input_index, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
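
For illustration, a minimal usage sketch for the new TF_GetNumberAttrForOpListInput. This is not part of the diff; the include paths and the choice of "ConcatV2" (whose first input, "values", is a list whose length is carried by its "N" attribute) are assumptions.

#include <stdio.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"  // assumed header for the declaration above

int main() {
  TF_Status* status = TF_NewStatus();
  // Ask which attribute holds the list length of ConcatV2's input 0 ("values").
  const char* attr_name =
      TF_GetNumberAttrForOpListInput("ConcatV2", /*input_index=*/0, status);
  if (TF_GetCode(status) == TF_OK) {
    // The returned pointer is owned by the global OpRegistry; expected to be "N".
    printf("list-length attr: %s\n", attr_name);
  }
  TF_DeleteStatus(status);
  return 0;
}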

View File

@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
{
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
// Test op list.
TF_Buffer op_list_buf = TF_GetOpList(lib);
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
ASSERT_EQ(op_list.op_size(), 1);
EXPECT_EQ("TestCApi1", op_list.op(0).name());
TF_DeleteLibraryHandle(lib);
}
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
{
TF_Buffer* op_list_buffer = TF_GetAllOpList();
tensorflow::OpList op_list;
@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) {
EXPECT_TRUE(found);
TF_DeleteBuffer(op_list_buffer);
}
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
{
// Test op list.
TF_Buffer op_list_buf = TF_GetOpList(lib);
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
ASSERT_EQ(op_list.op_size(), 1);
EXPECT_EQ("TestCApi", op_list.op(0).name());
}
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
TF_DeleteLibraryHandle(lib);
}
void TestEncodeDecode(int line, const std::vector<string>& data) {
@ -2349,14 +2347,8 @@ TEST(TestApiDef, TestCreateApiDef) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer* op_list_buf = TF_GetAllOpList();
status = TF_NewStatus();
TF_Status* status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@ -2376,7 +2368,6 @@ TEST(TestApiDef, TestCreateApiDef) {
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteBuffer(op_list_buf);
TF_DeleteLibraryHandle(lib);
}
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
@ -2384,14 +2375,8 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
// tf_cuda_cc_test() bazel rule and remove the next line.
if (!GPUDeviceName().empty()) return;
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer* op_list_buf = TF_GetAllOpList();
status = TF_NewStatus();
TF_Status* status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@ -2422,7 +2407,6 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteBuffer(op_list_buf);
TF_DeleteLibraryHandle(lib);
}
class DummyKernel : public tensorflow::OpKernel {

View File

@ -69,7 +69,7 @@ tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
visibility = [
"//learning/deepmind/courier:__pkg__",
"//learning/deepmind/courier:__subpackages__",
"//tensorflow:internal",
],
deps = [

View File

@ -79,10 +79,6 @@ struct TFE_TensorHandle {
tensorflow::Device* op_device)
: handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {}
TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
tensorflow::EagerContext* ctx)
: handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
tensorflow::TensorHandle* handle;

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,17 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "absl/strings/string_view.h"
namespace tensorflow {
namespace stream_executor {
namespace port {
REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc");
using StringPiece = absl::string_view;
} // namespace port
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
} // namespace tensorflow

View File

@ -170,6 +170,7 @@ cc_library_with_android_deps(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc(
":array_ops",
":const_op",
":math_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
],
)

View File

@ -190,11 +190,13 @@ cc_library(
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:stack",
"//tensorflow/core/kernels:variable_ops",
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
@ -240,6 +242,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
@ -499,6 +502,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) {
return errors::Internal("Could not find compilation device ",
device_type.type());
}
*result = registration->requires_compilation;
*result = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK();
}

View File

@ -525,7 +525,6 @@ Predicate* PredicateFactory::MakeAndOrImpl(
op->GetOperands().begin(),
op->GetOperands().end());
} else {
std::vector<Predicate*> sub_ops_intersection;
common_inner_operands.clear();
absl::c_copy_if(op->GetOperands(),
std::back_inserter(common_inner_operands),

View File

@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
Output loop_cond =
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
latch.output_false);
Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
latch.output_true, increment_by);
Output next_iteration =
@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue(
value, frame_name);
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
latch.output_false);
Output next_iteration = ops::NextIteration(
root.WithOpName(prefix + "/next_iteration"), latch.output_true);
CHECK(root.graph()

View File

@ -1122,8 +1122,11 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
fdef);
}
if (!reuse_existing_functions || library->Find(name) == nullptr) {
const FunctionDef* original_fdef = library->Find(name);
if (!reuse_existing_functions || original_fdef == nullptr) {
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
} else if (!FunctionDefsEqual(*original_fdef, fdef)) {
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
}
return Status::OK();
}

View File

@ -51,6 +51,12 @@ xla::StatusOr<Node*> AddHostComputeKeyPlaceholder(
return n;
}
// Returns true if the node is an XLA computation key placeholder.
bool IsKeyPlaceholderNode(const Node& n) {
return n.type_string() == "Placeholder" &&
absl::EndsWith(n.name(), "_key_placeholder");
}
// Returns nodes with given type.
std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
std::vector<Node*> result;
@ -107,6 +113,8 @@ xla::StatusOr<Node*> BuildRecvAtHostNode(
xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
Graph* g, const string& oc_cluster_name,
std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
// TODO(b/77601805): use out nodes for source node, instead of traversing all
// nodes.
std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
TF_ASSIGN_OR_RETURN(
@ -218,6 +226,8 @@ xla::StatusOr<Node*> BuildSendFromHostNode(
xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
Graph* g, const string& oc_cluster_name,
std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
// TODO(b/77601805): use in nodes for sink node, instead of traversing all
// nodes.
std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
TF_ASSIGN_OR_RETURN(
@ -258,7 +268,7 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
return absl::nullopt;
}
const PartialTensorShape shape = shapes[e->dst_input()];
const PartialTensorShape shape = shapes[e->src_output()];
if (!shape.IsFullyDefined()) {
return absl::nullopt;
}
@ -411,8 +421,7 @@ Status ConstructHostGraph(
if (node_map.find(n) != node_map.end()) {
// Already copied this node.
copy = node_map.at(n);
} else if (n->type_string() == "Placeholder" &&
absl::EndsWith(n->name(), "_key_placeholder")) {
} else if (IsKeyPlaceholderNode(*n)) {
// Change a).
copy = key_placeholder;
node_map[n] = copy;
@ -691,8 +700,7 @@ Status RewriteOutsideCompilationSubgraphFn::operator()(
// Step 4: add XLA cluster and outside compilation attr.
for (Node* n : (*graph)->nodes()) {
if (n->type_string() == "Placeholder" &&
absl::EndsWith(n->name(), "_key_placeholder")) {
if (IsKeyPlaceholderNode(*n)) {
continue;
}

View File

@ -221,8 +221,8 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
.WithOpName("static_shaped_slice"),
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
.node();
std::vector<int> compile_time_const_inputs;
compile_time_const_inputs.push_back(2);
std::vector<string> compile_time_const_inputs;
compile_time_const_inputs.push_back("size");
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
compile_time_const_inputs);
return status;
@ -314,15 +314,18 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) {
Status IncreaseDynamismForAutoJitPass::Run(
const GraphOptimizationPassOptions& options) {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
}
bool changed;
TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
if (changed) {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
}
if (changed && flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
}
return Status::OK();

View File

@ -129,8 +129,8 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) {
Op("ConcatV2"), AssignedDevice(kHostName),
Inputs(m_slice_size_0, Const(static_cast<int64>(500)), Const(zero_32))));
std::vector<int> compile_time_constant_inputs;
compile_time_constant_inputs.push_back(2);
std::vector<string> compile_time_constant_inputs;
compile_time_constant_inputs.push_back("size");
auto m_dynamic_slice = NodeWith(
Op("Slice"), AssignedDevice(kDeviceName),
Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),

View File

@ -39,12 +39,22 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return RET; \
} \
} while (0)
namespace tensorflow {
namespace {
Status PlatformInfoFromContext(OpKernelConstruction* ctx,
XlaPlatformInfo* result) {
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx,
}
if (!device_allocator) {
TF_ASSIGN_OR_RETURN(se::Platform* const platform,
se::MultiPlatformManager::PlatformWithId(platform_id));
xla::StatusOr<se::Platform*> maybe_platform =
se::MultiPlatformManager::PlatformWithId(platform_id);
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
xla_allocator = absl::make_unique<XlaAllocator>(
platform, ctx->device()->GetAllocator({}));
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
}
*result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
std::move(xla_allocator), device_allocator);
return Status::OK();
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
std::move(xla_allocator), device_allocator);
}
// A closure describing how to run a compiled version of a TensorFlow function.
@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
: OpKernel(ctx),
constants_(constants),
resources_(resources),
function_(function) {
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
}
function_(function),
platform_info_(PlatformInfoFromContext(ctx)) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
@ -333,18 +342,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
}
namespace {
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return RET; \
} \
} while (0)
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
@ -381,7 +378,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
return *func;
}
#undef OP_REQUIRES_OK_RETURN
bool MustCompileAttr(OpKernelConstruction* ctx) {
bool must_compile;
OP_REQUIRES_OK_RETURN(ctx, false,
ctx->GetAttr("must_compile", &must_compile));
return must_compile;
}
} // namespace
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
@ -396,10 +398,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
: OpKernel(ctx),
constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)) {
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_));
}
function_(FunctionAttr(ctx)),
platform_info_(PlatformInfoFromContext(ctx)),
must_compile_(MustCompileAttr(ctx)) {}
void XlaCompileOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaCompileOp " << def().name()
@ -409,13 +410,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) {
bool cannot_compile_cluster;
{
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster = cannot_compile_cluster_;
}
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
executable = nullptr;
} else {
OP_REQUIRES_OK(ctx, CompileToLocalExecutable(
ctx, function_, platform_info_, resources_,
constants_, /*lazy=*/!must_compile_, &client,
&variables, &kernel, &executable));
Status status = CompileToLocalExecutable(
ctx, function_, platform_info_, resources_, constants_,
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}
if (status.code() == error::UNIMPLEMENTED) {
LOG(WARNING) << "Compilation failed:" << status.ToString()
<< ". Falling back to TF function call.";
executable = nullptr;
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster_ = true;
}
}
AllocatorAttributes host_alloc_attrs;
@ -452,9 +470,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
ctx->set_output(1, compilation_successful);
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name();

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#include <atomic>
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
@ -33,6 +35,7 @@ namespace tensorflow {
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel {
protected:
// Indexes of compile-time constant inputs
std::vector<int> constants_;
const std::vector<int> constants_;
// Indexes of resource inputs
std::vector<int> resources_;
const std::vector<int> resources_;
NameAttrList function_;
XlaPlatformInfo platform_info_;
const NameAttrList function_;
const XlaPlatformInfo platform_info_;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel {
private:
// Indexes of compile-time constant inputs
std::vector<int> constants_;
const std::vector<int> constants_;
// Indexes of resource inputs
std::vector<int> resources_;
const std::vector<int> resources_;
NameAttrList function_;
const NameAttrList function_;
XlaPlatformInfo platform_info_;
bool must_compile_;
const bool must_compile_;
// cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
// error when compiling the cluster this _XlaCompile is supposed to compile.
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
// on any future calls to _XlaCompile.
bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
mutex cannot_compile_cluster_mu_;
};
class XlaRunOp : public OpKernel {
@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
private:
XlaPlatformInfo platform_info_;
const XlaPlatformInfo platform_info_;
};
} // namespace tensorflow

View File

@ -49,6 +49,25 @@ limitations under the License.
namespace tensorflow {
namespace {
// Aggregates information about what kinds of ops are allowed.
struct OperationFilter {
// Whether resource variable ops are allowed. We do not allow resource
// variable ops in called functions (either as direct TF calls or as higher
// order control flow ops) because we do not yet model their memory effects in
// jit/resource_variable_safety_analysis.
bool allow_resource_ops;
// Whether stateful RNG ops are allowed. XLA's RNG does not have the same
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
// auto-clustering stateful RNG ops.
bool allow_stateful_rng_ops;
};
bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
op_name == "TruncatedNormal";
}
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
@ -101,7 +120,7 @@ const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type,
bool allow_resource_ops, int depth,
const OperationFilter& op_filter, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_node' is a completely compilable loop.
@ -109,7 +128,7 @@ bool IsCompilableCall(const NodeDef& call_def,
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type,
bool allow_resource_ops, int depth,
const OperationFilter& op_filter, int depth,
FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
@ -124,7 +143,7 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
@ -140,7 +159,7 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
@ -154,7 +173,7 @@ bool IsCompilableWhile(const Node& while_node,
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type,
bool allow_resource_ops, int depth,
const OperationFilter& op_filter, int depth,
FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
@ -195,16 +214,20 @@ bool IsCompilableCall(const NodeDef& call_def,
continue;
if (node->type_string() == "While") {
// Handle functional While loop.
return IsCompilableWhile(*node, jit_device_type, allow_resource_ops,
depth + 1, lib_runtime);
return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1,
lib_runtime);
}
if (!allow_resource_ops &&
if (!op_filter.allow_resource_ops &&
(HasResourceInput(*node) || HasResourceOutput(*node))) {
return false;
}
if (!op_filter.allow_stateful_rng_ops &&
IsStatefulRandomOp(node->type_string())) {
return false;
}
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, allow_resource_ops,
depth + 1, lib_runtime)) {
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) {
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
<< node->name() << ": " << node->def().ShortDebugString();
return false;
@ -426,14 +449,28 @@ Status FindCompilationCandidates(
CHECK(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
OperationFilter op_filter;
op_filter.allow_resource_ops = registration->compile_resource_ops;
op_filter.allow_stateful_rng_ops =
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type,
registration->compile_resource_ops, 0, lib_runtime)) {
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
lib_runtime)) {
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
<< node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
if (!op_filter.allow_stateful_rng_ops &&
IsStatefulRandomOp(node->type_string())) {
VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
continue;
}
if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
// We don't have a way of returning values of type DT_RESOURCE from XLA
// computations so we avoid auto-clustering nodes producing DT_RESOURCE.
@ -444,6 +481,7 @@ Status FindCompilationCandidates(
<< node->type_string();
continue;
}
if (compile_time_const_nodes[node->id()]) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
@ -501,9 +539,7 @@ Status FindCompilationCandidates(
// registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
// for CPU/GPU.
if (node->type_string() == "While" &&
!IsCompilableWhile(*node, jit_device_type,
registration->compile_resource_ops, 0,
lib_runtime)) {
!IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) {
continue;
}
// _Arg nodes in a top-level function represent feeds.
@ -563,10 +599,12 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
&registration));
DeviceType jit_device_type(registration->compilation_device_name);
// We can always *compile* resource operations, even if we are sometimes
// unable to auto-cluster them.
const bool compile_resource_ops = true;
return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr);
// We can always *compile* resource operations and stateful RNGs, even if we
// are sometimes unable to auto-cluster them.
OperationFilter op_filter;
op_filter.allow_resource_ops = true;
op_filter.allow_stateful_rng_ops = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
}
Status MarkForCompilationPass::Run(
@ -577,10 +615,8 @@ Status MarkForCompilationPass::Run(
GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
bool fusion_only = flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
const FunctionLibraryDefinition* fld = options.flib_def;
@ -599,9 +635,6 @@ Status MarkForCompilationPass::Run(
return false;
}
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
@ -638,18 +671,21 @@ Status MarkForCompilationPass::Run(
return false;
}
// Otherwise use the value of global_jit_level.
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
// Otherwise use the value of global_jit_level and the device's
// autoclustering policy.
bool should_compile =
(ignore_registration || registration->enable_jit_by_default) &&
global_jit_level != OptimizerOptions::OFF;
registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways ||
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
global_jit_level != OptimizerOptions::OFF);
if (!should_compile) {
if (global_jit_level == OptimizerOptions::OFF) {
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
} else {
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
VLOG(2)
<< "Rejecting " << node->name()
<< ": autoclustering for device only when requested explicitly.";
}
}
return should_compile;
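
The hunk above, like several others in this diff (NodeRequiresCompilation, IsMustCompileDevice, the XLA CPU device registration), replaces the requires_compilation / enable_jit_by_default booleans with a single autoclustering policy. A minimal sketch of the enum these hunks reference, assuming it is nested in XlaOpRegistry as the qualified names suggest; only the three values used in this diff are shown:

// Sketch only; the real declaration lives in the XLA op registry header and may differ.
class XlaOpRegistry {
 public:
  enum class AutoclusteringPolicy {
    // Always auto-cluster nodes placed on this device (e.g. XLA_CPU when not
    // running in compile-on-demand mode).
    kAlways,
    // Auto-cluster only when global JIT compilation is enabled.
    kIfEnabledGlobally,
    // Auto-cluster only nodes explicitly marked for compilation.
    kIfExplicitlyRequested,
  };
};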
@ -1037,12 +1073,10 @@ Status MarkForCompilationPass::RunImpl(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
// Compile if this is a cluster of >= min_cluster_size compilable operators.
// Also, always compile if the operator is placed on a device that requires
// compilation, or if it contains at least one op that is marked for
// Also, always compile if it contains at least one op that is marked for
// compilation that is not an Identity op.
if (effective_cluster_sizes[cluster] >= min_cluster_size ||
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation) ||
registration->requires_compilation) {
(effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) {
string& name = cluster_names[cluster];
if (name.empty()) {

View File

@ -923,9 +923,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/shape_rng"], "");
EXPECT_NE(clusters["test/reshape"], "");
EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
EXPECT_EQ(clusters["test/shape_rng"], "");
EXPECT_EQ(clusters["test/reshape"], "");
}
TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
@ -1061,5 +1060,48 @@ TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
// Improved Heuristics should prevent this probably.
EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}
TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
absl::string_view xla_cpu_device =
"/job:worker/replica:0/task:0/device:XLA_CPU:0";
Scope root = Scope::NewRootScope().ExitOnError();
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
Output c = ops::Add(root.WithOpName("test/c"), a, b);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
for (Node* n : graph->nodes()) {
if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
n->set_assigned_device_name(string(xla_cpu_device));
}
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/a"], "");
EXPECT_NE(clusters["test/b"], "");
EXPECT_NE(clusters["test/c"], "");
}
TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) {
Scope root = Scope::NewRootScope().ExitOnError();
Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
Output c = ops::Add(root.WithOpName("test/c"), a, b);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/a"], "");
EXPECT_EQ(clusters["test/b"], "");
}
} // namespace
} // namespace tensorflow

View File

@ -485,6 +485,16 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
return {int_list_attr.first, attr_value};
}
std::pair<string, AttrValue> impl::AttrLiteralHelper(
const std::pair<string, absl::Span<const string>>& string_list_attr) {
AttrValue attr_value;
AttrValue::ListValue* list = attr_value.mutable_list();
for (string s : string_list_attr.second) {
list->add_s(s);
}
return {string_list_attr.first, attr_value};
}
impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
impl::NodeMatcherProperties props;
props.set_attr(std::move(attr));

View File

@ -170,6 +170,9 @@ std::pair<string, AttrValue> AttrLiteralHelper(
std::pair<string, AttrValue> AttrLiteralHelper(
const std::pair<string, absl::Span<const int>>& int_list_attr);
std::pair<string, AttrValue> AttrLiteralHelper(
const std::pair<string, absl::Span<const string>>& string_list_attr);
} // namespace impl
// -----------------------------------------------------------------------------

View File

@ -210,7 +210,8 @@ bool IsIntraClusterEdge(const Edge& edge) {
bool IsMustCompileDevice(const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return registration->requires_compilation;
return registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
}
return false;

View File

@ -173,9 +173,7 @@ Status XlaCompileOnDemandOp::Compile(
XlaCompiler::Options options;
options.device_type = metadata.jit_device_type();
options.client = metadata.client();
auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), FunctionDefLibrary{});
options.flib_def = flib_def.get();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.shape_representation_fn = metadata.shape_representation_fn();
XlaCompiler::CompileOptions compile_options;

View File

@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
registration.requires_compilation = !compile_on_demand;
registration.enable_jit_by_default = false;
registration.autoclustering_policy =
compile_on_demand
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
@ -60,7 +62,6 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.device_name = DEVICE_XLA_CPU;
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.transfer_as_literal = false;
options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options);
devices->push_back(device.release());

View File

@ -201,12 +201,18 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
jit_device_name_(options.compilation_device_name),
platform_(options.platform),
use_multiple_streams_(options.use_multiple_streams),
transfer_as_literal_(options.transfer_as_literal),
shape_representation_fn_(options.shape_representation_fn) {
VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
<< this;
thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
/*num_threads=*/1));
// We have multiple device to device streams to allow for some concurrency
// between transfers. The particular value of '4' is chosen fairly
// arbitrarily. It may be necessary to make this tunable via
// XlaDevice::Options.
static constexpr int kNumDeviceToDeviceStreams = 4;
device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}
XlaDevice::~XlaDevice() {
@ -274,8 +280,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
std::shared_ptr<se::Stream> host_to_device_stream = stream_;
std::shared_ptr<se::Stream> device_to_host_stream = stream_;
std::shared_ptr<se::Stream> host_to_device_stream;
std::shared_ptr<se::Stream> device_to_host_stream;
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@ -283,8 +290,18 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
TF_RETURN_IF_ERROR(
EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
&need_new_device_context));
}
host_to_device_stream = host_to_device_stream_;
device_to_host_stream = device_to_host_stream_;
device_to_device_streams = device_to_device_streams_;
} else {
host_to_device_stream = stream_;
device_to_host_stream = stream_;
device_to_device_streams = {stream_};
}
if (!need_new_device_context) {
@ -302,8 +319,9 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
// ensures that the streams remain live for the duration of a run, even if
// an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
stream_, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
stream_, std::move(host_to_device_stream),
std::move(device_to_host_stream), std::move(device_to_device_streams),
client(), shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;

View File

@ -108,11 +108,6 @@ class XlaDevice : public LocalDevice {
// The name of the compilation device (e.g., "XLA_CPU_JIT");
string compilation_device_name;
// 'transfer_as_literal' is true if device<->host transfers must be done
// using XLA's TransferLiteral{To,From}Device interface. If false, we can
// use ThenMemcpy instead.
bool transfer_as_literal = false;
// If 'use_multiple_streams' is true, we create separate streams for
// compute, host-to-device, and device-to-host communication.
bool use_multiple_streams = false;
@ -188,6 +183,7 @@ class XlaDevice : public LocalDevice {
se::Platform* const platform_; // Not owned.
// Memory allocator associated with this device.
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
@ -203,9 +199,11 @@ class XlaDevice : public LocalDevice {
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
// If use_multiple_streams_, transfers between different devices are performed
// using these streams.
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
GUARDED_BY(mu_);
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// The device context accessed by all users of the XlaDevice, set by calls to

View File

@ -53,16 +53,17 @@ void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaDeviceContext::XlaDeviceContext(
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
std::shared_ptr<se::Stream> device_to_host_stream,
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
xla::LocalClient* client,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool)
: stream_(std::move(compute_stream)),
host_to_device_stream_(std::move(host_to_device_stream)),
device_to_host_stream_(std::move(device_to_host_stream)),
device_to_device_streams_(std::move(device_to_device_streams)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)),
thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
@ -75,71 +76,6 @@ XlaDeviceContext::XlaDeviceContext(
}
}
Status XlaDeviceContext::TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
// Create a reference to hold onto host_tensor until after the literal has
// been transferred. Also make sure the literal exists until the function
// asynchronously completes, as it will be wrapped in an xla::LiteralSlice.
TensorReference ref(host_tensor);
auto literal = std::make_shared<xla::BorrowingLiteral>(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow(
stream_->parent(), shaped_buffer)) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(event.get());
xla_tensor->ResetDefinitionEvent(std::move(event),
host_to_device_stream_.get());
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
// We don't defer the call to done() onto the stream here, and the reasons why
// this is correct are subtle. We assume that:
// a) all consumers of the device tensor will wait for its definition event.
// b) if the tensor is destroyed, then the memory allocator will not hand out
// the same buffers until the transfer has completed.
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
return Status::OK();
}
void XlaDeviceContext::TransferLiteralFromDevice(
Tensor* host_tensor, const Tensor& device_tensor,
const StatusCallback& done) const {
xla::MutableBorrowingLiteral literal;
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer](xla::Status status) {
ref.Unref();
done([&]() -> Status {
VLOG(1) << "Transfer from device as literal: "
<< shaped_buffer.ToString();
return status;
}());
});
}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
@ -158,54 +94,73 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
xla::StatusOr<TensorShape> shape_or_status =
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
if (!shape_or_status.ok()) {
done(shape_or_status.status());
Status status = [&]() -> Status {
TF_ASSIGN_OR_RETURN(TensorShape shape,
shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype()));
// The device tensor should always be fresh.
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
xla_tensor->set_host_tensor(*cpu_tensor);
TF_RETURN_IF_ERROR(
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal()));
xla::BorrowingLiteral literal(
static_cast<const char*>(DMAHelper::base(cpu_tensor)),
xla_tensor->shaped_buffer().on_host_shape());
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< xla_tensor->shaped_buffer().ToString();
if (UseMultipleStreams() &&
!transfer_manager_->CanShapedBufferBeAccessedNow(
stream_->parent(), xla_tensor->shaped_buffer())) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(event.get());
xla_tensor->ResetDefinitionEvent(std::move(event),
host_to_device_stream_.get());
}
return Status::OK();
}();
if (!status.ok()) {
done(status);
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
return;
}
}
Status status;
if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
// Create a reference to hold onto cpu_tensor until after the literal has
// been transferred
TensorReference ref(*cpu_tensor);
if (UseMultipleStreams()) {
// Unref the host tensor when the transfer completes.
// We don't defer the call to done() onto the stream here, and the reasons
// why this is correct are subtle. We assume that:
// a) all consumers of the device tensor will wait for its definition event.
// b) if the tensor is destroyed, then the memory allocator will not hand
// out the same buffers until the transfer has completed.
host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
done(status);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = host_to_device_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
host_to_device_stream_.get(), block_status.error_message().c_str());
}
host_to_device_stream_->ThenDoHostCallback([ref, done]() {
ref.Unref();
done(Status::OK());
});
}
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
}
done(status);
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@ -225,30 +180,31 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
void* dst_ptr = DMAHelper::base(cpu_tensor);
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get());
Status status;
if (transfer_as_literal_) {
TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
return;
} else {
device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
xla::MutableBorrowingLiteral literal;
TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal));
done(status);
TensorReference ref(*device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal,
[ref, xla_tensor, done](xla::Status status) {
done([&]() -> Status {
VLOG(1) << "Transfer from device as literal: "
<< xla_tensor->shaped_buffer().ToString();
return status;
}());
ref.Unref();
});
}
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
DCHECK_GT(device_to_device_streams_.size(), 0);
absl::MutexLock lock(&mu_);
int stream = next_stream_;
next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
return device_to_device_stream(stream);
}
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@ -50,7 +51,8 @@ class XlaDeviceContext : public DeviceContext {
std::shared_ptr<se::Stream> compute_stream,
std::shared_ptr<se::Stream> host_to_device_stream,
std::shared_ptr<se::Stream> device_to_host_stream,
xla::LocalClient* client, bool transfer_as_literal,
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
xla::LocalClient* client,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
thread::ThreadPool* thread_pool);
@ -61,14 +63,26 @@ class XlaDeviceContext : public DeviceContext {
absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}
se::Stream* device_to_host_stream() const {
return device_to_host_stream_.get();
}
se::Stream* device_to_device_stream(int index) const {
return device_to_device_streams_.at(index).get();
}
xla::TransferManager* transfer_manager() const { return transfer_manager_; }
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
// Returns a device-to-device stream, in round-robin fashion.
se::Stream* GetDeviceToDeviceStream();
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const;
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
// The main compute stream of the device, used to synchronize the transfer
@ -80,16 +94,22 @@ class XlaDeviceContext : public DeviceContext {
// The stream to use for transferring data from device to host. Can be
  // identical to stream_, but must not be nullptr.
std::shared_ptr<se::Stream> device_to_host_stream_;
// Streams to use for transferring data directly between different devices,
// e.g., over NVLINK.
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// Thread pool used for running closures
thread::ThreadPool* thread_pool_;
absl::Mutex mu_;
int next_stream_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/stack.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel {
.Device(DEVICE) \
.TypeConstraint<string>("T") \
.HostMemory("input"), \
RetvalOp);
RetvalOp); \
\
REGISTER_KERNEL_BUILDER(Name("StackV2") \
.Device(DEVICE) \
.HostMemory("max_size") \
.HostMemory("handle"), \
StackOp); \
REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("T", TYPES), \
TemplatedStackPushOp</*allow_swapping=*/false>); \
REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
.Device(DEVICE) \
.HostMemory("handle") \
.TypeConstraint("elem_type", TYPES), \
StackPopOp); \
REGISTER_KERNEL_BUILDER( \
Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);
// TODO(phawkins): currently we do not register the QueueEnqueueMany,
// TODO(b/118881356): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the

View File

@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
std::vector<Device*>* devices) {
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.requires_compilation = true;
registration.enable_jit_by_default = false;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
@ -59,7 +59,6 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.device_name = DEVICE_XLA_GPU;
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_GPU_XLA_JIT;
options.transfer_as_literal = false;
options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options);

View File

@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
registration.requires_compilation = true;
registration.enable_jit_by_default = false;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_resource_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);
@ -60,7 +60,6 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
options.device_name = DEVICE_XLA_INTERPRETER;
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.transfer_as_literal = false;
options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options);
devices->push_back(device.release());

View File

@ -375,6 +375,27 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "resampler_ops_test",
size = "small",
srcs = ["resampler_ops_test.py"],
disabled_backends = [
# TODO(b/74459949) Support BatchDot in CPU backend.
"cpu",
"cpu_ondemand",
],
# TODO(b/112295522): figure out how to make OSS build pass.
tags = ["no_oss"],
deps = [
":xla_test",
"//tensorflow/contrib/resampler:resampler_ops",
"//tensorflow/contrib/resampler:resampler_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
@ -489,8 +510,6 @@ tf_xla_py_test(
name = "function_test",
size = "small",
srcs = ["function_test.py"],
# Functions are not implemented in the on-demand compilation model yet.
disabled_backends = "cpu_ondemand",
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -680,9 +699,6 @@ tf_xla_py_test(
name = "random_ops_test",
size = "small",
srcs = ["random_ops_test.py"],
disabled_backends = [
"cpu_ondemand",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -698,6 +714,10 @@ tf_xla_py_test(
size = "medium",
srcs = ["reduce_ops_test.py"],
shard_count = 5,
tags = [
# TODO(b/119059212): Re-enable this test in OSS.
"no_oss",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -713,7 +733,6 @@ tf_xla_py_test(
name = "reduce_window_test",
size = "small",
srcs = ["reduce_window_test.py"],
disabled_backends = ["cpu_ondemand"],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -822,8 +841,6 @@ tf_xla_py_test(
name = "stack_ops_test",
size = "small",
srcs = ["stack_ops_test.py"],
# Stack ops are not implemented in the on-demand compilation model yet.
disabled_backends = "cpu_ondemand",
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -851,7 +868,7 @@ tf_xla_py_test(
size = "small",
srcs = ["tensor_array_ops_test.py"],
# TensorArray ops are not implemented in the on-demand compilation model yet.
disabled_backends = "cpu_ondemand",
disabled_backends = ["cpu_ondemand"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -872,7 +889,7 @@ tf_xla_py_test(
size = "small",
srcs = ["tensor_list_ops_test.py"],
# TensorList ops are not implemented in the on-demand compilation model yet.
disabled_backends = "cpu_ondemand",
disabled_backends = ["cpu_ondemand"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -952,7 +969,6 @@ tf_xla_py_test(
name = "while_test",
size = "small",
srcs = ["while_test.py"],
disabled_backends = ["cpu_ondemand"],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
@ -1109,6 +1125,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
@ -1219,7 +1236,6 @@ tf_xla_py_test(
name = "xla_ops_test",
size = "medium",
srcs = ["xla_ops_test.py"],
disabled_backends = ["cpu_ondemand"],
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",

View File

@ -178,6 +178,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
[0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
self._testBinary(
gen_nn_ops.leaky_relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10],
dtype=dtype))
self._testBinary(
gen_nn_ops.softmax_cross_entropy_with_logits,
np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),

View File

@ -50,6 +50,8 @@ def tf_xla_py_test(
"""
if disabled_backends == None:
disabled_backends = []
if type(disabled_backends) != "list":
fail("disabled_backends must be a list of strings", "disabled_backends")
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
test_names = []

View File

@ -135,7 +135,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
self.assertAllCloseAccordingToType(
np.array([-2.60260963, -4.29698515]),
var0.eval(),
float_rtol=1e-5,
float_rtol=1e-4,
half_rtol=1e-2)
self.assertAllCloseAccordingToType(
np.array([-0.28432083, -0.56694895]),
@ -167,7 +167,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5)
np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5,
float_rtol=1e-4)
self.assertAllCloseAccordingToType(
np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5)

View File

@ -448,8 +448,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
for dtype in self.float_types:
self._assertForwardOpMatchesExpected(
np.array([[1, 2]], dtype=dtype), [3, 3],
expected=np.array(
[[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32))
expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]],
dtype=np.float32))
def testAlignCorners1x2To3x2Grad(self):
for dtype in self.float_types:
@ -477,8 +477,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
for dtype in self.float_types:
self._assertForwardOpMatchesExpected(
np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3],
expected=np.array(
[[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32))
expected=np.array([[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]],
dtype=np.float32))
def testAlignCorners2x2To3x3Grad(self):
self._assertBackwardOpMatchesExpected(
@ -498,8 +498,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
np.array([[7, 13], [22, 4]], dtype=np.float32),
input_shape=[3, 3],
dtype=dtype,
expected=np.array(
[[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32))
expected=np.array([[7, 0, 13], [0, 0, 0], [22, 0, 4]],
dtype=np.float32))
def testAlignCorners4x4To3x3(self):
for dtype in self.float_types:
@ -507,8 +507,8 @@ class ResizeBilinearTest(xla_test.XLATestCase):
np.array(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
dtype=dtype), [3, 3],
expected=np.array(
[[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32))
expected=np.array([[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]],
dtype=np.float32))
def testAlignCorners4x4To3x3Grad(self):
for dtype in self.float_types:
@ -516,41 +516,39 @@ class ResizeBilinearTest(xla_test.XLATestCase):
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32),
input_shape=[4, 4],
dtype=dtype,
expected=np.array(
[[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3],
[7, 4, 4, 9]],
dtype=np.float32))
expected=np.array([[1, 1, 1, 3], [2, 1.25, 1.25, 3],
[2, 1.25, 1.25, 3], [7, 4, 4, 9]],
dtype=np.float32))
def testAlignCorners3x3To9x9(self):
for dtype in self.float_types:
self._assertForwardOpMatchesExpected(
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9],
expected=np.array(
[[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [
1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75
], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [
3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25
], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [
4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75
], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [
6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25
], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
[[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00],
[1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75],
[2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50],
[3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25],
[4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00],
[4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75],
[5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50],
[6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25],
[7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
dtype=np.float32))
def testAlignCorners3x3To9x9Grad(self):
for dtype in self.float_types:
self._assertBackwardOpMatchesExpected(
np.array(
[[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [
1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75
], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [
3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25
], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [
4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75
], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [
6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25
], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
dtype=np.float32),
np.array([[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00],
[1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75],
[2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50],
[3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25],
[4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00],
[4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75],
[5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50],
[6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25],
[7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]],
dtype=np.float32),
input_shape=[3, 3],
dtype=dtype,
expected=np.array(
@ -571,12 +569,12 @@ class ResizeBilinearTest(xla_test.XLATestCase):
(np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
[[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0,
[16, 16],
expected=7 * (np.array(
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
dtype=np.float32) + np.array(
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
[12], [13], [14], [15]],
dtype=np.float32)),
expected=7 *
(np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
dtype=np.float32) +
np.array([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
[12], [13], [14], [15]],
dtype=np.float32)),
large_tolerance=True)
def testNonAlignCorners3x2To6x4(self):
@ -600,6 +598,26 @@ class ResizeBilinearTest(xla_test.XLATestCase):
expected=np.array(expected_data, dtype=dtype),
align_corners=False)
def testNonAlignCorners3x2To6x4Batch2(self):
input_data = [[[64, 32], [32, 64], [50, 100]], [[32, 16], [16, 32],
[25, 50]]]
expected_data = [[[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0],
[32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0],
[50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]],
[[32.0, 24.0, 16.0, 16.0], [24.0, 24.0, 24.0, 24.0],
[16.0, 24.0, 32.0, 32.0], [20.5, 30.75, 41.0, 41.0],
[25.0, 37.5, 50.0, 50.0], [25.0, 37.5, 50.0, 50.0]]]
for dtype in self.float_types:
input_image = np.array(input_data, dtype=dtype)
expected = np.array(expected_data, dtype=dtype)
with self.cached_session() as sess, self.test_scope():
image = array_ops.placeholder(input_image.dtype)
resized = gen_image_ops.resize_bilinear(
image, [6, 4], align_corners=False)
out = sess.run(resized, {image: input_image[:, :, :, np.newaxis]})
self.assertAllClose(expected[:, :, :, np.newaxis], out)
class NonMaxSuppressionTest(xla_test.XLATestCase):
@ -804,5 +822,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 3)
self.assertAllClose(indices_tf[:num_valid], [0, 2, 4])
if __name__ == "__main__":
test.main()
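The align-corners forward expectations above (e.g. 1x2 -> 3x3 and 2x2 -> 3x3) follow the usual align_corners=True mapping: output index i samples input coordinate i * (in - 1) / (out - 1) and blends the four nearest input values. A minimal NumPy sketch, with a hypothetical helper name, that reproduces the 2x2 -> 3x3 expectation:

import numpy as np

def resize_bilinear_align_corners(image, out_h, out_w):
    # image is a 2-D array; with align_corners=True, output index i maps to
    # input coordinate i * (in - 1) / (out - 1), and the four neighbors are
    # blended bilinearly.
    in_h, in_w = image.shape
    out = np.zeros((out_h, out_w), dtype=np.float32)
    scale_h = float(in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    scale_w = float(in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    for i in range(out_h):
        for j in range(out_w):
            y, x = i * scale_h, j * scale_w
            y0, x0 = int(np.floor(y)), int(np.floor(x))
            y1, x1 = min(y0 + 1, in_h - 1), min(x0 + 1, in_w - 1)
            dy, dx = y - y0, x - x0
            top = (1 - dx) * image[y0, x0] + dx * image[y0, x1]
            bottom = (1 - dx) * image[y1, x0] + dx * image[y1, x1]
            out[i, j] = (1 - dy) * top + dy * bottom
    return out

print(resize_bilinear_align_corners(np.array([[1., 2.], [3., 4.]]), 3, 3))
# expected: [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]]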

View File

@ -45,6 +45,7 @@ limitations under the License.
#include <random>
#include <unordered_map>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
@ -2687,6 +2688,37 @@ TEST_F(OpTest, Reverse) {
});
}
TEST_F(OpTest, ReverseSequence) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims(/*min_rank=*/2);
auto type = Choose<DataType>(kAllXlaTypes);
int64 rank = dims.size();
// Choose random batch and sequence dimensions.
std::vector<int> shuffled_dim_ids(rank);
absl::c_iota(shuffled_dim_ids, 0);
absl::c_shuffle(shuffled_dim_ids, generator());
shuffled_dim_ids.resize(2);
int batch_dim = shuffled_dim_ids[0];
int seq_dim = shuffled_dim_ids[1];
int batch_size = dims[batch_dim];
int max_seq_len = dims[seq_dim];
std::vector<int32> seq_lens(batch_size);
std::uniform_int_distribution<int32> d(0, max_seq_len);
absl::c_generate(seq_lens, [&]() { return d(generator()); });
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ReverseSequence")
.RandomInput(type, dims)
.Input(test::AsTensor<int32>(seq_lens))
.Attr("seq_dim", seq_dim)
.Attr("batch_dim", batch_dim)
.Attr("T", type)
.Attr("Tlen", DT_INT32));
});
}
TEST_F(OpTest, ReverseV2) {
Repeatedly([this]() {
auto type = Choose<DataType>(kAllXlaTypes);

View File

@ -0,0 +1,156 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for resampler ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.contrib import resampler
from tensorflow.contrib.resampler.ops import gen_resampler_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ResamplerOpsTest(xla_test.XLATestCase):
def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected):
with self.test_session() as sess, self.test_scope():
input_image = array_ops.placeholder(image_np.dtype)
warp = array_ops.placeholder(warp_np.dtype)
resampled = resampler.resampler(input_image, warp, name='resampler')
out = sess.run(resampled, {input_image: image_np, warp: warp_np})
self.assertAllCloseAccordingToType(
expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2)
def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np,
expected_grad_data, expected_grad_warp):
with self.cached_session() as sess, self.test_scope():
input_image = array_ops.placeholder(input_np.dtype)
warp = array_ops.placeholder(warp_np.dtype)
grad_output = array_ops.placeholder(grad_output_np.dtype)
grad_data, grad_warp = gen_resampler_ops.resampler_grad(
input_image, warp, grad_output)
grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], {
input_image: input_np,
warp: warp_np,
grad_output: grad_output_np
})
self.assertAllCloseAccordingToType(
expected_grad_warp, grad_warp_tf, half_rtol=1e-2, bfloat16_rtol=3e-2)
self.assertAllCloseAccordingToType(
expected_grad_data, grad_data_tf, half_rtol=1e-2, bfloat16_rtol=3e-2)
def testSimple(self):
for dtype in self.float_types:
input_shape = [1, 2, 2, 1]
input_rgb_data = [0, 5, 13, 54]
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
warp_shape = [1, 2]
warp_data = [0.7, 0.6]
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
expected = [[26.42]]
self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
grad_output = np.ones([1, 1], dtype=dtype)
expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001],
[0.42000002]]]]
expected_grad_warp = [[26.60000038, 38.20000076]]
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
expected_grad_data,
expected_grad_warp)
def testMultiChannel(self):
for dtype in self.float_types:
input_shape = [1, 2, 2, 3]
input_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
warp_shape = [1, 2]
warp_data = [0.7, 0.6]
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
expected = [[59.58000183, 146.94000244, 107.37999725]]
self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
grad_output = np.ones([1, 3], dtype=dtype)
expected_grad_data = [[[[0.12, 0.12, 0.12],
[0.27999997, 0.27999997, 0.27999997]],
[[0.18000001, 0.18000001, 0.18000001],
[0.42000002, 0.42000002, 0.42000002]]]]
expected_grad_warp = [[199, 30]]
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
expected_grad_data,
expected_grad_warp)
def testBatch2Height3byWidth3RGB(self):
for dtype in self.float_types:
input_shape = [2, 3, 3, 3]
input_rgb_data = [
0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1, 30, 105, 2, 40, 115,
3, 50, 125, 4, 60, 135, 5, 70, 145, 6, 0, 5, 13, 54, 135, 226, 37, 8,
234, 90, 255, 1, 30, 105, 2, 40, 115, 3, 50, 125, 4, 60, 135, 5, 70,
145, 6
]
input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
# 2 batches and 2 samples for each batch.
warp_shape = [2, 2, 2]
warp_data = [0.7, 0.6, 1, 0.7, 0.9, 1.2, 1.3, 1.6]
warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
expected_forward = [[[43.92, 128.4, 65.86], [37.2, 114., 69.2]],
[[40.6, 122.8, 2.5], [51., 126, 4.1]]]
self._assertForwardOpMatchesExpected(input_np, warp_np, expected_forward)
expected_grad_data = [[[[0.12, 0.12, 0.12],
[0.57999998, 0.57999998, 0.57999998],
[0., 0., 0.]],
[[0.18000001, 0.18000001, 0.18000001],
[1.12, 1.12, 1.12], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0.08000001, 0.08000001, 0.08000001],
[0.99999988, 0.99999988, 0.99999988],
[0.11999997, 0.11999997, 0.11999997]],
[[0.02000001, 0.02000001, 0.02000001],
[0.60000008, 0.60000008, 0.60000008],
[0.17999998, 0.17999998, 0.17999998]]]]
expected_grad_warp = [[[33.39999008, -96.20000458], [-26.10000229,
-278.]],
[[-162.99998474, 39.99999619], [21., 63.]]]
grad_output = np.ones([2, 2, 3], dtype=dtype)
self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output,
expected_grad_data,
expected_grad_warp)
if __name__ == '__main__':
test.main()
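The testSimple expectation above can be reproduced by hand from the bilinear blend the XLA kernel performs: with warp (x, y) = (0.7, 0.6) the top-left corner is (0, 0), the ratio is (px, py) = (0.7, 0.6), and the output is the weighted sum of the 2x2 neighborhood. A minimal single-channel NumPy sketch (hypothetical helper name):

import numpy as np

def resample_one(data, x, y):
    # data is [height(y), width(x)]; sample the continuous position (x, y)
    # with the bilinear weights [(1-px)(1-py), px(1-py), (1-px)py, px*py].
    fx, fy = int(np.floor(x)), int(np.floor(y))
    px, py = x - fx, y - fy
    neighbors = np.array([data[fy, fx], data[fy, fx + 1],
                          data[fy + 1, fx], data[fy + 1, fx + 1]])
    weights = np.array([(1 - px) * (1 - py), px * (1 - py),
                        (1 - px) * py, px * py])
    return np.dot(weights, neighbors)

data = np.array([[0., 5.], [13., 54.]])   # same values as testSimple
print(resample_one(data, 0.7, 0.6))       # -> 26.42

With a grad_output of ones, the weight vector [0.12, 0.28, 0.18, 0.42] computed here is also exactly what expected_grad_data above scatters back onto the four input positions.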

View File

@ -358,6 +358,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-0.05, 6.05, 5]], dtype=dtype),
expected=np.array([[0, 6, 5]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.leaky_relu,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array([1, 2, 3, 4], dtype=dtype),

View File

@ -194,6 +194,7 @@ cc_library(
":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",

View File

@ -178,6 +178,32 @@ tf_kernel_library(
],
)
# A separate cc_library for resampler_ops is needed because resampler is in
# contrib/, and thus the declaration of resampler cannot be pulled into the deps
# of xla_ops. Therefore, resampler_ops is its own cc_library target, and its
# corresponding tf_kernel_library is defined in contrib/resampler/BUILD.
cc_library(
name = "resampler_ops",
srcs = ["resampler_ops.cc"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_library(
name = "conv_op_helpers",
srcs = ["conv_op_helpers.cc"],

View File

@ -231,20 +231,22 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
const int64 batch_dim_size =
builder->GetShape(input).ValueOrDie().dimensions(0);
if (num_extended[0] > 0) {
auto slice =
xla::Slice(input_data, {0, in_size[0] - 1, 0, 0},
{1, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
auto slice = xla::Slice(
input_data, {0, in_size[0] - 1, 0, 0},
{batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
for (int i = 0; i < num_extended[0]; i++) {
input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
}
}
if (num_extended[1] > 0) {
auto slice =
xla::Slice(input_data, {0, 0, in_size[1] - 1, 0},
{1, in_size[0] + num_extended[0], in_size[1], channels},
{1, 1, 1, 1});
auto slice = xla::Slice(
input_data, {0, 0, in_size[1] - 1, 0},
{batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
{1, 1, 1, 1});
for (int i = 0; i < num_extended[1]; i++) {
input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
}

View File

@ -15,14 +15,12 @@ limitations under the License.
// Native XLA implementations of XLA Relu Ops
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
namespace tensorflow {
namespace {
@ -37,6 +35,7 @@ class ReluOp : public XlaOpKernel {
ctx->SetOutput(0, xla::Max(zero, ctx->Input(0)));
}
};
REGISTER_XLA_OP(Name("Relu"), ReluOp);
class Relu6Op : public XlaOpKernel {
public:
@ -49,6 +48,22 @@ class Relu6Op : public XlaOpKernel {
ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six));
}
};
REGISTER_XLA_OP(Name("Relu6"), Relu6Op);
class LeakyReluOp : public XlaOpKernel {
public:
explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto features = ctx->Input("features");
auto output =
xla::Max(features, features * xla::ScalarLike(features, alpha_));
ctx->SetOutput(0, output);
}
float alpha_;
};
REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp);
class ReluGradOp : public XlaOpKernel {
public:
@ -64,6 +79,7 @@ class ReluGradOp : public XlaOpKernel {
ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero));
}
};
REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp);
class Relu6GradOp : public XlaOpKernel {
public:
@ -83,11 +99,24 @@ class Relu6GradOp : public XlaOpKernel {
ctx->SetOutput(0, out);
}
};
REGISTER_XLA_OP(Name("Relu"), ReluOp);
REGISTER_XLA_OP(Name("Relu6"), Relu6Op);
REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp);
REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp);
class LeakyReluGradOp : public XlaOpKernel {
public:
explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto gradients = ctx->Input("gradients");
auto features = ctx->Input("features");
auto output =
xla::Select(xla::Gt(features, xla::ScalarLike(features, 0)), gradients,
gradients * xla::ScalarLike(gradients, alpha_));
ctx->SetOutput(0, output);
}
float alpha_;
};
REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp);
} // namespace
} // namespace tensorflow
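The LeakyReluOp and LeakyReluGradOp added above lower LeakyRelu to xla::Max(x, alpha*x) and its gradient to xla::Select(x > 0, g, alpha*g). As a quick sanity check, a minimal NumPy sketch of the same math (hypothetical helpers, alpha fixed at 0.2 to match the expectations in the unary and binary op tests above):

import numpy as np

def leaky_relu(features, alpha=0.2):
    # Forward pass: max(x, alpha * x), mirroring xla::Max in LeakyReluOp.
    features = np.asarray(features, dtype=np.float32)
    return np.maximum(features, alpha * features)

def leaky_relu_grad(gradients, features, alpha=0.2):
    # Backward pass: pass the gradient through where features > 0, otherwise
    # scale it by alpha, mirroring xla::Select in LeakyReluGradOp.
    gradients = np.asarray(gradients, dtype=np.float32)
    features = np.asarray(features, dtype=np.float32)
    return np.where(features > 0, gradients, alpha * gradients)

print(leaky_relu([-2, -1, 0, 1, 2]))
# expected: [-0.4, -0.2, 0, 1, 2]
print(leaky_relu_grad(np.arange(1, 11, dtype=np.float32),
                      [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9]))
# expected: [0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10]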

View File

@ -0,0 +1,541 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>
#include <vector>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
using xla::XlaOp;
// TODO(b/112295522): note that sampling from image boundary is not currently
// being handled properly.
// Calculates the bilinear weight tensor, given basis ratio (px, py) of the
// sampling position:
// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
// 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2].
//
// The returned tensor has dimensions [batch, dim_0, ... dim_n, 4].
XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio,
const TensorShape warp_shape,
xla::PrimitiveType xla_type) {
auto first_term = xla::ConstantR2<float>(
ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}});
first_term = xla::ConvertElementType(first_term, xla_type);
auto warp_dims = warp_shape.dim_sizes();
std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1);
broadcast_dims.push_back(4);
broadcast_dims.push_back(2);
const int64 broadcast_dims_size = broadcast_dims.size();
std::vector<int64> last_two_dims_indices = {(broadcast_dims_size - 2),
(broadcast_dims_size - 1)};
xla::Shape broadcast_shape =
xla::ShapeUtil::MakeShape(xla_type, broadcast_dims);
auto broadcast_first_term =
xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices);
// Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n,
// 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the
// [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last
// dimension.
std::vector<int64> ratio_broadcast_indices(broadcast_dims.size());
std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0);
ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2);
auto broadcast_ratio =
xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices);
auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio;
// Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to
// flip the signs of the second and the third term.
auto sign_change = xla::ConstantR2<float>(
ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}});
sign_change = xla::ConvertElementType(sign_change, xla_type);
auto broadcast_sign_change =
xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices);
auto flipped = first_term_subtract_weights * broadcast_sign_change;
// Build up the final bilinear weight tensor by multiply reduction, which
// gives:
// [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
// for each 4 neighboring pixels where px and py are the weight of the target
// pixel we are sampling from.
return xla::Reduce(
flipped, xla::One(ctx->builder(), xla_type),
xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()),
{broadcast_dims_size - 1});
}
// Concatenates the batch indices to the (x, y) coordinate indices.
// This is done by first creating an Iota tensor that represents the current
// batch it is in, then concatenate with the given (coordinate) indices.
//
// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where
// the last dimension of size 3 in turn is [batch_number, x, y].
// The [batch_number, x, y] dimension is needed because the indices
// [x,y] alone cannot allow the xla::Gather operation to gather from the input
// data, which is of dimension [batch, height(y), width(x), channel] with
// 'batch' being the first dimension.
XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
const TensorShape& warp_shape) {
// We need to create an iota tensor with the same batch dimension.
std::vector<int64> dimensions;
for (auto dim : warp_shape) {
dimensions.push_back(dim.size);
}
// Except the last dimension, which is of size 1.
dimensions.back() = 1;
auto batch_indices =
xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions),
/*iota_dimension=*/0);
return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1);
}
// Gathers the 2x2 neighbors of the input starting_indices, and return a
// tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels].
// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last
// dimension of size 3 is (batch_no, x, y).
XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices,
int64 data_channels, int warp_dims) {
xla::GatherDimensionNumbers gather_dim_numbers;
const int64 neighbor_data_dimensions = warp_dims + 2;
// Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2,
// data_channels], the offset dimensions for Gather is the last 3 dimensions.
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3);
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2);
gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1);
// The last dimension of 'gather_indices' is the starting indices for gather.
gather_dim_numbers.set_index_vector_dim(warp_dims - 1);
gather_dim_numbers.add_collapsed_slice_dims(0);
gather_dim_numbers.add_start_index_map(0);
// Since input is of dimension [batch, height(y), width(x), channel], and warp
// is of dimension [batch, x, y], the ordering of x, y here needs to be
// swapped when gathering.
gather_dim_numbers.add_start_index_map(2);
gather_dim_numbers.add_start_index_map(1);
// Data dimensions are [batch, x, y, channel].
// Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels].
auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers,
/*slice_sizes=*/{1, 2, 2, data_channels});
// Collapse the ...,2,2,... dimensions into ...,4,...
return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims});
}
// Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the
// resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels].
// This function can also be seen as the inverse of 'Gather2by2Neighbors'.
XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
XlaOp updates, int64 warp_dims,
xla::PrimitiveType xla_type) {
xla::ScatterDimensionNumbers scatter_dim_numbers;
const int64 neighbor_data_dimensions = warp_dims + 2;
// Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2,
// data_channels], the update window dimensions is the last 3 dimensions.
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3);
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2);
scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1);
scatter_dim_numbers.set_index_vector_dim(warp_dims - 1);
scatter_dim_numbers.add_inserted_window_dims(0);
scatter_dim_numbers.add_scatter_dims_to_operand_dims(0);
// Since input is of dimension [batch, height(y), width(x), channel], and warp
// is of dimension [batch, x, y], the ordering of x, y here needs to be
// swapped when scattering.
scatter_dim_numbers.add_scatter_dims_to_operand_dims(2);
scatter_dim_numbers.add_scatter_dims_to_operand_dims(1);
return xla::Scatter(grad_data, indices, updates,
xla::CreateScalarAddComputation(xla_type, ctx->builder()),
scatter_dim_numbers);
}
// Build computation for the backprop into input 'data'.
// Where input:
// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
// ratio is of dimension [batch, dim_0, ...dim_n, 2]
// gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
//
// Output:
// scatter-add to each 2x2 grad_data neighbor:
// grad_data[fx, fy, chan] += output_grad * dx * dy
// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
// where (dx, dy) is (1 - ratio).
XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
XlaOp gather_indices, xla::PrimitiveType warp_type,
TensorShape warp_shape, int64 data_channels,
xla::Shape data_shape) {
// Weights tensor has dimension [batch, dim_0, ... dim_n, 4].
auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
auto warp_dims = warp_shape.dim_sizes();
std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
warp_dims.end() - 1);
std::vector<int64> reshaped_weights_dims = warp_dims_without_last_dims;
// Reshape the last dimension of size 4 to two dimensions [2, 2].
reshaped_weights_dims.push_back(2);
reshaped_weights_dims.push_back(2);
std::vector<int64> reshape_dims(warp_shape.dims());
std::iota(reshape_dims.begin(), reshape_dims.end(), 0);
// The dimension is [batch, dim_0,..., dim_n, 2, 2].
auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims,
/*new_sizes=*/reshaped_weights_dims);
std::vector<int64> weights_with_channels_dims = reshaped_weights_dims;
weights_with_channels_dims.push_back(data_channels);
auto weights_with_channels_shape =
xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims);
std::vector<int64> reshaped_weights_indices(reshaped_weights_dims.size());
std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(),
0);
// The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel].
auto broadcast_reshaped_weights = xla::BroadcastInDim(
reshaped_weights, weights_with_channels_shape, reshaped_weights_indices);
std::vector<int64> grad_output_indices(warp_dims_without_last_dims.size());
std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0);
grad_output_indices.push_back(weights_with_channels_dims.size() - 1);
XlaOp broadcast_grad_output = xla::BroadcastInDim(
grad_output, weights_with_channels_shape, grad_output_indices);
auto grad_output_multiply_weights =
broadcast_grad_output * broadcast_reshaped_weights;
auto grad_data = xla::ConstantLiteral(
ctx->builder(), xla::Literal::CreateFromShape(data_shape));
return ScatterToGradData(ctx, grad_data, gather_indices,
grad_output_multiply_weights, warp_shape.dims(),
warp_type);
}
// Build computation for the backprop into input 'warp'.
// Where input:
// warp is of dimension [batch, dim_0, ...dim_n, 2]
// grad_output is of dimension [batch, dim_0, ...dim_n, channel]
// ratio is of dimension [batch, dim_0, ...dim_n, 2]
// gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
// data is of dimension [batch, x, y, channel]
//
// Output (simplified by ignoring the batch dimensions):
// Since the forward path has:
// output = dot(weights * neighbors)
// The backprop into warp will therefore be:
// grad_warp = output_grad * d_output / d_warp
// = output_grad * (d_weights / d_warp * neighbors + d_neighbors /
// d_warp * weight)
// Where:
// d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py]
// d_weights / d_warp_y = [-(1 - px), -px, (1-px), px]
// and
// d_neighbors / d_warp_x = 0
//
// Therefore:
// grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy)
// grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy)
//
// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the
// bottom right corner in a 2x2 neighborhood.
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
XlaOp gather_indices, XlaOp data,
TensorShape warp_shape, int64 data_channels,
xla::PrimitiveType data_type) {
auto warp_dims = warp_shape.dim_sizes();
std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
warp_dims.end() - 1);
std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims;
neighbor_broadcast_dims.push_back(4);
// With dimension [batch, dim_0, ...dim_n, 4]
auto neighbor_broadcast_shape =
xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims);
// The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
auto neighbors_data = Gather2by2Neighbors(
ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
const int64 last_warp_dim = warp_shape.dims() - 1;
// Since we will be creating the dot product of:
// lhs: [batch, dim_0, ...dim_n, 4]
// and
// rhs: [batch, dim_0, ...dim_n, 4, data_channels]
// we choose the last dimension of lhs and the second last dimension of rhs,
// with size 4, as the contracting dimension.
xla::DotDimensionNumbers dot_dims;
for (int i = 0; i < warp_shape.dims() - 1; ++i) {
dot_dims.add_lhs_batch_dimensions(i);
dot_dims.add_rhs_batch_dimensions(i);
}
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
// img_cxcy - img_fxcy
auto bottom_right_minus_bottom_left = xla::DotGeneral(
xla::BroadcastInDim(
xla::ConvertElementType(
xla::ConstantR1<float>(ctx->builder(), {0, 0, -1, 1}), data_type),
neighbor_broadcast_shape, {last_warp_dim}),
neighbors_data, dot_dims, /*precision_config=*/nullptr);
// img_cxfy - img_fxfy
auto top_right_minus_top_left = xla::DotGeneral(
xla::BroadcastInDim(
xla::ConvertElementType(
xla::ConstantR1<float>(ctx->builder(), {-1, 1, 0, 0}), data_type),
neighbor_broadcast_shape, {last_warp_dim}),
neighbors_data, dot_dims, /*precision_config=*/nullptr);
// img_cxcy - img_cxfy
auto bottom_right_minus_top_right = xla::DotGeneral(
xla::BroadcastInDim(
xla::ConvertElementType(
xla::ConstantR1<float>(ctx->builder(), {0, -1, 0, 1}), data_type),
neighbor_broadcast_shape, {last_warp_dim}),
neighbors_data, dot_dims, /*precision_config=*/nullptr);
// img_fxcy - img_fxfy
auto bottom_left_minus_top_left = xla::DotGeneral(
xla::BroadcastInDim(
xla::ConvertElementType(
xla::ConstantR1<float>(ctx->builder(), {-1, 0, 1, 0}), data_type),
neighbor_broadcast_shape, {last_warp_dim}),
neighbors_data, dot_dims, /*precision_config=*/nullptr);
// Slice out x and y.
auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1,
/*stride=*/1, /*dimno=*/last_warp_dim);
auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2,
/*stride=*/1, /*dimno=*/last_warp_dim);
// Build 1 - y and 1 - x.
auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y;
auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x;
auto x_before_reduce =
grad_output * weight_y * bottom_right_minus_bottom_left +
one_minus_y * top_right_minus_top_left;
std::vector<int64> reshaped_sizes = warp_dims_without_last_dims;
reshaped_sizes.push_back(1);
std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size());
std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0);
// Reduce-add along the channel dimension.
auto x_result =
xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type),
xla::CreateScalarAddComputation(data_type, ctx->builder()),
{last_warp_dim});
// Reshape before concatenating with y values.
XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes);
auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right +
one_minus_x * bottom_left_minus_top_left;
// Reduce-add along the channel dimension.
auto y_result =
xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type),
xla::CreateScalarAddComputation(data_type, ctx->builder()),
{last_warp_dim});
XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes);
return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y},
last_warp_dim);
}
class ResamplerOp : public XlaOpKernel {
public:
explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape data_shape = ctx->InputShape("data");
OP_REQUIRES(ctx, data_shape.dims() == 4,
errors::InvalidArgument("data must be 4-dimensional",
data_shape.DebugString()));
const int64 data_channels = data_shape.dim_size(3);
xla::PrimitiveType data_type = ctx->input_xla_type(0);
TensorShape warp_shape = ctx->InputShape("warp");
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
errors::InvalidArgument("warp must be at least 2-dimensional",
warp_shape.DebugString()));
for (int size : warp_shape.dim_sizes()) {
OP_REQUIRES(ctx, size > 0,
errors::InvalidArgument("warp sizes must be positive, got [",
size, "]"));
}
const int64 last_warp_dim = warp_shape.dims() - 1;
// Last dimension of warp shape must be of size 2.
OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
XlaOp data = ctx->Input("data");
XlaOp warp = ctx->Input("warp");
// Find the coordinates of the top left corner for the 2x2 region to be
// sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the
// last dimension of size 2 in turn is [x, y].
XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
// The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
auto neighbors_data = Gather2by2Neighbors(
ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
// Dimensions are [batch, dim_0, ... dim_n, 2].
XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type);
// Obtain the bilinear blending weights, the dimension is [batch, dim_0,
// ...dim_n, 4].
auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type);
// Since we will be creating the dot product of:
// lhs: [batch, dim_0, ...dim_n, 4]
// and
// rhs: [batch, dim_0, ...dim_n, 4, data_channels]
// we choose the last dimension of lhs and the second last dimension of rhs,
// with size 4, as the contracting dimension.
xla::DotDimensionNumbers dot_dims;
for (int i = 0; i < warp_shape.dims() - 1; ++i) {
dot_dims.add_lhs_batch_dimensions(i);
dot_dims.add_rhs_batch_dimensions(i);
}
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
/*precision_config=*/nullptr);
ctx->SetOutput(0, blended_pixels);
}
};
REGISTER_XLA_OP(Name("Resampler"), ResamplerOp);
class ResamplerGradOp : public XlaOpKernel {
public:
explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
DataType output_dtype;
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape data_shape_tf = ctx->InputShape("data");
OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
errors::InvalidArgument("data must be 4-dimensional",
data_shape_tf.DebugString()));
const int64 data_channels = data_shape_tf.dim_size(3);
xla::PrimitiveType data_type = ctx->input_xla_type(0);
TensorShape warp_shape = ctx->InputShape("warp");
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
errors::InvalidArgument("warp must be at least 2-dimensional",
warp_shape.DebugString()));
for (int size : warp_shape.dim_sizes()) {
OP_REQUIRES(ctx, size > 0,
errors::InvalidArgument("warp sizes must be positive, got [",
size, "]"));
}
// Last dimension of warp shape must be of size 2.
OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
xla::PrimitiveType warp_type = ctx->input_xla_type(1);
TensorShape output_grad_shape = ctx->InputShape("grad_output");
OP_REQUIRES(
ctx, output_grad_shape.dims() >= 2,
errors::InvalidArgument("output_grad must be at least 2-dimensional",
output_grad_shape.DebugString()));
// Dimensions are [batch, x, y, channel].
XlaOp data = ctx->Input("data");
xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf);
// Dimensions are [batch, dim_0, ...dim_n, 2].
XlaOp warp = ctx->Input("warp");
// Dimensions are [batch, dim_0, ...dim_n, channel].
XlaOp grad_output = ctx->Input("grad_output");
// Find the top left corner coordinate for the region to be sampled from.
// The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
// of size 2 in turn is [x, y].
XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
// Dimensions are [batch, dim_0, ... dim_n, 2]
XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
// Indices for gathering neighboring pixels.
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
auto grad_data =
CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type,
warp_shape, data_channels, data_shape);
auto grad_warp =
CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
warp_shape, data_channels, data_type);
ctx->SetOutput(0, grad_data);
ctx->SetOutput(1, grad_warp);
}
};
REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp);
} // namespace
} // namespace tensorflow
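CalculateGradWarp above reduces the warp gradient to finite differences of the 2x2 neighborhood weighted by the ratio, as spelled out in its comment. A minimal single-channel NumPy sketch (hypothetical helper name, grad_output = 1) that reproduces expected_grad_warp = [26.6, 38.2] from testSimple in the resampler test:

import numpy as np

def resampler_grad_warp(data, x, y, grad_output=1.0):
    # data is [height(y), width(x)]; (fx, fy) is the top-left corner and
    # (cx, cy) the bottom-right corner of the sampled 2x2 neighborhood.
    fx, fy = int(np.floor(x)), int(np.floor(y))
    cx, cy = fx + 1, fy + 1
    px, py = x - fx, y - fy
    img_fxfy, img_cxfy = data[fy, fx], data[fy, cx]
    img_fxcy, img_cxcy = data[cy, fx], data[cy, cx]
    # grad_warp_x = py*(img_cxcy - img_fxcy) + (1-py)*(img_cxfy - img_fxfy)
    # grad_warp_y = px*(img_cxcy - img_cxfy) + (1-px)*(img_fxcy - img_fxfy)
    gx = grad_output * (py * (img_cxcy - img_fxcy) +
                        (1 - py) * (img_cxfy - img_fxfy))
    gy = grad_output * (px * (img_cxcy - img_cxfy) +
                        (1 - px) * (img_fxcy - img_fxfy))
    return gx, gy

data = np.array([[0., 5.], [13., 54.]])     # same values as testSimple
print(resampler_grad_warp(data, 0.7, 0.6))  # -> (26.6, 38.2)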

View File

@ -17,8 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@ -61,113 +63,79 @@ class ReverseSequenceOp : public XlaOpKernel {
const auto seq_lens = context->Input(1);
const int64 batch_size = input_shape.dim_size(batch_dim_);
if (batch_size == 0) {
context->SetOutput(0, input);
return;
}
const DataType input_type = context->input_type(0);
const DataType seq_lens_type = context->input_type(1);
// Given the input
//
// 012345
// 6789AB
//
// and sequence lens {2, 3} we:
//
// 1. Reverse and pad each row to get
//
// 543210XXXXXX
// BA9876XXXXXX
//
// 2. Gather out the suffix from each row to get
//
// 10XXXX
// 876XXX
//
// 3. Select from the input and the array created by (2) to get the result.
//
// 102345
// 8769AB
const xla::PrimitiveType input_type = context->input_xla_type(0);
const xla::PrimitiveType seq_lens_type = context->input_xla_type(1);
const int64 max_seq_len = input_shape.dim_size(seq_dim_);
xla::Shape input_xla_shape;
OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
&input_xla_shape));
xla::Shape seq_lens_xla_shape;
OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
&seq_lens_xla_shape));
xla::XlaOp rev = xla::Rev(input, {seq_dim_});
const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
seq_lens_xla_shape,
input_xla_shape,
});
auto padding_config = xla::MakeNoPaddingConfig(input_shape.dims());
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
max_seq_len);
xla::XlaOp padded =
xla::Pad(rev, xla::Zero(builder, input_type), padding_config);
// For each entry in the batch, reverse the sequence.
// TODO(b/65689298): generalize the Map() operator to non-scalar cases and
// use it here, instead of a While loop.
// Form a start indices tensor with shape [2, batch_size]. For each batch
// entry we have a (batch offset, seq offset) pair.
xla::XlaOp start_indices = xla::ConcatInDim(
builder,
{
xla::Iota(builder,
xla::ShapeUtil::MakeShape(seq_lens_type, {1, batch_size}),
/*iota_dimension=*/1),
xla::Reshape(xla::ScalarLike(seq_lens, max_seq_len) - seq_lens,
{1, batch_size}),
},
/*dimension=*/0);
// Condition: lambda (i, _, _): i < batch_size
auto condition_builder =
builder->CreateSubBuilder("reverse_sequence_condition");
{
auto param =
xla::Parameter(condition_builder.get(), 0, tuple_shape, "param");
auto i = xla::GetTupleElement(param, 0);
xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(),
seq_lens_type, batch_size));
xla::GatherDimensionNumbers dnums;
// The first dimension of start_indices contains the batch/seq dim choice.
dnums.set_index_vector_dim(0);
dnums.add_start_index_map(batch_dim_);
dnums.add_start_index_map(seq_dim_);
// All other dimensions other than the batch dim are offset dimensions.
for (int i = 0; i < input_shape.dims(); ++i) {
if (i != batch_dim_) {
dnums.add_offset_dims(i);
}
}
auto condition = condition_builder->Build();
OP_REQUIRES_OK(context, condition.status());
dnums.add_collapsed_slice_dims(batch_dim_);
auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
{
auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param");
auto i = xla::GetTupleElement(param, 0);
auto seq_lens = xla::GetTupleElement(param, 1);
auto output = xla::GetTupleElement(param, 2);
auto slice_sizes = input_shape.dim_sizes();
slice_sizes[batch_dim_] = 1;
// seq_len is the sequence length of the current batch element (rank 1)
auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1});
xla::XlaOp output = xla::Gather(padded, start_indices, dnums, slice_sizes);
// Indices is the offset of the batch element in the input.
auto batch_element_indices =
xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
{input_shape.dims()});
batch_element_indices = xla::DynamicUpdateSlice(
batch_element_indices, xla::Reshape(i, {1}),
xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
seq_lens_type, batch_dim_),
{1}));
// Slice out the current batch element and pad it out in the sequence
// dimension.
TensorShape slice_shape = input_shape;
slice_shape.set_dim(batch_dim_, 1);
slice_shape.set_dim(seq_dim_, max_seq_len);
auto slice = xla::DynamicSlice(output, batch_element_indices,
slice_shape.dim_sizes());
auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
slice_shape.dim_size(seq_dim_));
slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type),
padding_config);
// Now slice out the reversed sequence from its actual start.
// sequence_start_indices is the offset of the start of the reversed
// sequence in the input. The slice will go into the padding, however, we
// will mask off these elements and replace them with elements from the
// original input so their values do not matter.
auto sequence_start_indices =
xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
{slice_shape.dims()});
sequence_start_indices = xla::DynamicUpdateSlice(
sequence_start_indices,
xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
max_seq_len),
seq_len),
xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
seq_lens_type, seq_dim_),
{1}));
slice = xla::DynamicSlice(slice, sequence_start_indices,
slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
output = xla::DynamicUpdateSlice(output, slice, batch_element_indices);
xla::Tuple(
body_builder.get(),
{xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
seq_lens, output});
}
auto body = body_builder->Build();
OP_REQUIRES_OK(context, body.status());
auto loop_output = xla::While(
condition.ValueOrDie(), body.ValueOrDie(),
xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
xla::Rev(input, {seq_dim_})}));
auto output = xla::GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
xla::XlaOp iota =
xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len);
// Mask out elements after the sequence length, and copy the corresponding
// elements from the input.
xla::XlaOp iota = xla::Iota(builder, seq_lens_type, max_seq_len);
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});
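
// --- Illustrative sketch, not part of this commit ---
// A standalone C++ reference for what the lowering above computes, assuming
// batch_dim = 0 and seq_dim = 1 on a rank-2 input: each row is reversed only
// over its first seq_lens[b] elements, and everything past the sequence
// length is left untouched (which is what the iota < seq_lens mask enforces).
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<std::vector<float>> ReverseSequenceRef(
    std::vector<std::vector<float>> input, const std::vector<int>& seq_lens) {
  for (size_t b = 0; b < input.size(); ++b) {
    std::reverse(input[b].begin(), input[b].begin() + seq_lens[b]);
  }
  return input;
}

int main() {
  auto out = ReverseSequenceRef({{1, 2, 3, 4}, {5, 6, 7, 8}}, {3, 2});
  for (const auto& row : out) {
    for (float v : row) std::printf("%g ", v);
    std::printf("\n");  // prints "3 2 1 4" then "6 5 7 8".
  }
  return 0;
}
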


@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};
REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp);
REGISTER_XLA_OP(
Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
StackOp);
class StackPushOp : public XlaOpKernel {
public:
@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};
REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);
class StackPopOp : public XlaOpKernel {
public:
@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};
REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);
class StackCloseOp : public XlaOpKernel {
public:
@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};
REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);
} // anonymous namespace
} // namespace tensorflow


@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@ -75,6 +76,222 @@ Status CheckFeedFetchNameConflicts(const string& kind,
return Status::OK();
}
// For graph `g`, copy all function call nodes' FunctionDefs from `lookup_fld`
// to `fld`. This ensures that `fld` can instantiate the functions called in
// graph `g`.
Status CopyAssociatedFunctions(Graph* g,
const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld) {
for (Node* n : g->op_nodes()) {
for (const auto& associated_function :
GetAssociatedFunctions(*n, lookup_fld)) {
switch (associated_function.type()) {
case AssociatedFunctionInfo::kFunctionCallNode: {
const FunctionDef* fdef =
lookup_fld->Find(associated_function.func_name());
if (!fdef) {
return errors::Internal(
"Cannot find function ", associated_function.func_name(),
" for function call node ", n->DebugString());
}
TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
break;
}
case AssociatedFunctionInfo::kSymbolicGradient:
case AssociatedFunctionInfo::kFunctionAttr:
break;
}
}
}
return Status::OK();
}
// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
Graph* g,
const std::unordered_map<int, const Node*>& const_input_index_to_node) {
// Collect all _Arg nodes.
std::unordered_map<int, Node*> arg_nodes;
for (Node* n : g->op_nodes()) {
if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
arg_nodes[index] = n;
}
}
for (const auto& iter : const_input_index_to_node) {
int arg_index = iter.first;
Node* const_node = g->CopyNode(iter.second);
Node* arg_node = arg_nodes[arg_index];
// Collect all usages of the _Arg node.
struct OutEdgeInfo {
int dst_node_id, dst_input;
};
std::vector<OutEdgeInfo> usages;
for (const Edge* e : arg_node->out_edges()) {
if (e->IsControlEdge()) {
continue;
}
usages.push_back({e->dst()->id(), e->dst_input()});
}
for (int i = 0; i < usages.size(); i++) {
// Make a copy of `usage_node`, and change its input to const node.
Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
NodeDef replace_def = usage_node->def();
*replace_def.mutable_input(usages[i].dst_input) = const_node->name();
TF_ASSIGN_OR_RETURN(Node * replace_node,
ReplaceNode(g, usage_node, replace_def));
const Edge* usage_edge;
TF_RETURN_IF_ERROR(
replace_node->input_edge(usages[i].dst_input, &usage_edge));
g->RemoveEdge(usage_edge);
g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
// Later entries in `usages` might have `usage_node` as dst node, but
// `usage_node` is removed. Replace such entries with `replace_node`.
for (int j = i + 1; j < usages.size(); j++) {
if (usages[j].dst_node_id == usages[i].dst_node_id) {
usages[j].dst_node_id = replace_node->id();
}
}
}
}
return Status::OK();
}
// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
Status PropagateConstIntoFuncAttr(
Node* n, const string& attr_name,
const std::unordered_map<int, const Node*>& const_input_index_to_node,
const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld) {
// Instantiate the function.
NameAttrList func_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
if (!fdef) {
return errors::Internal("Cannot find function ", func_attr.name(),
" for node ", n->name());
}
FunctionBody* fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fdef, AttrSlice(&func_attr.attr()), lookup_fld,
[lookup_fld](const string& op, const OpDef** sig) {
return lookup_fld->LookUpOpDef(op, sig);
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
// Rewrite _Arg usages with Const node.
Graph* func_graph = fbody->graph;
TF_RETURN_IF_ERROR(
ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
// Save rewritten function.
FunctionDef replace_fdef;
string new_func_name =
fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
// Change the node to use rewritten function.
func_attr.set_name(new_func_name);
n->ClearAttr(attr_name);
n->AddAttr(attr_name, func_attr);
// Copy associated functions.
TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
return Status::OK();
}
// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld) {
// Notice that the first input of an If node is the predicate; the other
// inputs are function inputs.
std::unordered_map<int, const Node*> const_input_index_to_node;
for (int i = 1; i < if_node->num_inputs(); i++) {
const Node* input_node;
TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
if (input_node->type_string() == "Const") {
const_input_index_to_node[i - 1] = input_node;
}
}
if (const_input_index_to_node.empty()) {
return Status::OK();
}
// Rewrite "then_branch" and "else_branch" function, replace usage of those
// _Arg nodes with corresponding const node.
for (const auto& attr_name :
std::vector<string>{"then_branch", "else_branch"}) {
TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
}
return Status::OK();
}
// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld) {
// For a "While" node, we should only replace _Arg nodes that are loop
// invariant. For such _Arg nodes, the return value's input comes directly
// from the corresponding arg.
std::unordered_map<int, const Node*> const_input_index_to_node;
NameAttrList body_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
if (!body_func) {
return errors::Internal("Cannot find body function ", body_attr.name(),
" for While node ", while_node->name());
}
for (int i = 0; i < while_node->num_inputs(); i++) {
const Node* input_node;
TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
if (input_node->type_string() != "Const") {
continue;
}
// Check if i-th retval's input comes from i-th arg directly.
const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
auto output_arg_input = body_func->ret().find(output_arg.name());
if (output_arg_input == body_func->ret().end()) {
return errors::Internal("Cannot find input for output arg ",
output_arg.name(), " in function ",
body_attr.name());
}
const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
if (output_arg_input->second != input_arg.name()) {
continue;
}
const_input_index_to_node[i] = input_node;
}
if (const_input_index_to_node.empty()) {
return Status::OK();
}
// Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
// corresponding const node.
for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
}
return Status::OK();
}
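
// --- Illustrative sketch, not part of this commit ---
// The loop-invariance test used above, reduced to plain C++: input i of a
// While node may be const-propagated only if the body returns its i-th
// argument unchanged, i.e. ret[output_arg[i]] == input_arg[i]. The names
// below (input_args, output_args, ret) are hypothetical stand-ins for the
// FunctionDef fields consulted by PropagateConstIntoWhileNode.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

std::vector<int> LoopInvariantIndices(
    const std::vector<std::string>& input_args,
    const std::vector<std::string>& output_args,
    const std::map<std::string, std::string>& ret) {
  std::vector<int> invariant;
  for (size_t i = 0; i < input_args.size(); ++i) {
    auto it = ret.find(output_args[i]);
    if (it != ret.end() && it->second == input_args[i]) {
      invariant.push_back(static_cast<int>(i));
    }
  }
  return invariant;
}

int main() {
  // Body "while_body(x, y) { return (x, y + 1) }": only x is loop-invariant.
  auto inv = LoopInvariantIndices({"x", "y"}, {"out0", "out1"},
                                  {{"out0", "x"}, {"out1", "add:z:0"}});
  for (int i : inv) std::printf("invariant arg: %d\n", i);  // prints 0.
  return 0;
}
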
} // namespace
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
@ -520,4 +737,17 @@ xla::StatusOr<Node*> BuildIdentityNode(
return id_node;
}
Status PropagateConstIntoFunctionalNodes(
Graph* g, const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld) {
for (Node* n : g->op_nodes()) {
if (n->type_string() == "If") {
TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
} else if (n->type_string() == "While") {
TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
}
}
return Status::OK();
}
} // namespace tensorflow


@ -183,6 +183,20 @@ xla::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
DataType dtype, const Node* input,
absl::optional<string> requested_device);
// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
// body functions to use the Const nodes instead of original _Arg nodes.
//
// For example, say we have the following computation:
// shape = constant_op.constant([1])
// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
// If we do not rewrite the then/else functions, they will use the _Arg node
// as the shape input for tf.ones/tf.zeros. But XLA requires that shape input
// to be a compile-time constant, so XLA compilation will fail. This rewriting
// process changes the shape input to a Const node.
Status PropagateConstIntoFunctionalNodes(
Graph* g, const FunctionLibraryDefinition* lookup_fld,
FunctionLibraryDefinition* fld);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_


@ -756,6 +756,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
graph.get(), options_.flib_def, local_flib_def_.get()));
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(


@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default;
// Lazily register the CPU and GPU JIT devices the first time
// GetCompilationDevice is called.
static void* registration_init = [&registry]() {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
mutex_lock lock(registry.mutex_);
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_CPU];
registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
registration.requires_compilation = false;
registration.enable_jit_by_default = false;
registration.autoclustering_policy =
cpu_global_jit
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
registration.compile_resource_ops = false;
}
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
DeviceRegistration& registration =
registry.compilation_devices_[DEVICE_GPU];
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.requires_compilation = false;
registration.enable_jit_by_default = true;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
registration.compile_resource_ops = false;
}
return nullptr;


@ -66,19 +66,26 @@ class XlaOpRegistry {
public:
typedef OpKernel* (*Factory)(OpKernelConstruction*);
enum class AutoclusteringPolicy {
// Enable autoclustering if the user requests it, e.g., via
// experimental_jit_scope. Does not autocluster if the JIT is enabled
// globally (e.g., via the OptimizerOptions in the TF session
// configuration.)
kIfExplicitlyRequested,
// Enable autoclustering if explicitly requested, or if the JIT is enabled
// globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
kIfEnabledGlobally,
// Always try to autocluster ops placed on this device.
kAlways,
};
// Describes how to compile operators assigned to a device.
struct DeviceRegistration {
// The name of the XLA compilation device to use to compile code.
string compilation_device_name;
// Do operators assigned to this device require compilation?
bool requires_compilation;
// If !requires_compilation, should we try to JIT operators on this device
// when XLA JIT compilation is enabled globally via the SessionOptions?
// (It is still possible to explicitly mark operators to JIT compile, even
// if enable_jit_by_default is false.)
bool enable_jit_by_default;
// When should we autocluster operators assigned to this device?
AutoclusteringPolicy autoclustering_policy;
// Enable compilation of operators that use DT_RESOURCE types?
bool compile_resource_ops = false;
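
// --- Illustrative sketch, not part of this commit ---
// A hypothetical mapping from the two booleans this enum replaces
// (requires_compilation, enable_jit_by_default) onto AutoclusteringPolicy,
// matching how the registrations above were rewritten (CPU: explicit request
// unless --tf_xla_cpu_global_jit; GPU: enabled globally). The kAlways arm is
// an assumption for devices that previously required compilation.
#include <cstdio>

enum class AutoclusteringPolicy {
  kIfExplicitlyRequested,
  kIfEnabledGlobally,
  kAlways,
};

AutoclusteringPolicy FromLegacyFlags(bool requires_compilation,
                                     bool enable_jit_by_default) {
  if (requires_compilation) return AutoclusteringPolicy::kAlways;
  return enable_jit_by_default ? AutoclusteringPolicy::kIfEnabledGlobally
                               : AutoclusteringPolicy::kIfExplicitlyRequested;
}

int main() {
  std::printf("GPU maps to kIfEnabledGlobally: %d\n",
              FromLegacyFlags(false, true) ==
                  AutoclusteringPolicy::kIfEnabledGlobally);
  return 0;
}
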


@ -7,6 +7,7 @@ package_group(
packages = [
"//tensorflow/compiler/...",
"//tensorflow/contrib/tpu/...",
"//third_party/py/jax/...",
],
)


@ -1,7 +1,6 @@
<p align="center">
<img width="200" src="xlalogo.png"/>
<img width="200" src="./g3doc/images/xlalogo.png"/>
</p>
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
algebra that optimizes TensorFlow computations. See the
[documentation](https://www.tensorflow.org/performance/xla/) for more details.
algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/overview.md).


@ -1,3 +1,3 @@
# XLA: Accelerated Linear Algebra
These are the docs for: https://www.tensorflow.org/extend/xla
These are the docs for: https://www.tensorflow.org/xla


@ -0,0 +1,29 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Ecosystem
path: /ecosystem
is_default: true
menu:
- include: /ecosystem/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide
contents:
- title: XLA overview
path: /xla/overview
- title: Broadcasting semantics
path: /xla/broadcasting
- title: Developing a new backend for XLA
path: /xla/developing_new_backend
- title: Using JIT compilation
path: /xla/jit
- title: Operation semantics
path: /xla/operation_semantics
- title: Shapes and layout
path: /xla/shapes
- title: Using AOT compilation
path: /xla/tfcompile


@ -0,0 +1,35 @@
book_path: /xla/_book.yaml
project_path: /xla/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: XLA is a compiler that optimizes TensorFlow computations.
items:
- classname: devsite-landing-row-50
description: >
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
algebra that optimizes TensorFlow computations. The results are
improvements in speed, memory usage, and portability on server and mobile
platforms. The XLA framework is experimental and in active development.
For details, read the <a href="./overview">XLA guide</a>.
- classname: devsite-landing-row-cards
items:
- heading: XLA - TensorFlow, compiled
image_path: /ecosystem/images/tf-logo-card-16x9.png
path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html
buttons:
- label: Read on Google Developers blog
path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html
- heading: XLA at the Dev Summit
youtube_id: kAOanJczHA0
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=kAOanJczHA0
- heading: XLA on GitHub
image_path: /ecosystem/images/github-card-16x9.png
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla


@ -1,6 +1,6 @@
name: XLA
breadcrumb_name: XLA
home_url: /extend/xla
home_url: /xla/
parent_project_metadata_path: /_project.yaml
description: >
XLA is a compiler-based linear algebra execution engine.


@ -1,16 +0,0 @@
toc:
- heading: XLA
- title: XLA overview
path: /extend/xla/
- title: Broadcasting semantics
path: /extend/xla/broadcasting
- title: Developing a new backend for XLA
path: /extend/xla/developing_new_backend
- title: Using JIT compilation
path: /extend/xla/jit
- title: Operation semantics
path: /extend/xla/operation_semantics
- title: Shapes and layout
path: /extend/xla/shapes
- title: Using AOT compilation
path: /extend/xla/tfcompile


@ -29,8 +29,6 @@ namespace xla {
/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
const Shape& shape, absl::Span<const int64> multi_index) {
DCHECK_EQ(shape.dimensions_size(), multi_index.size());
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
for (size_t i = 0; i < multi_index.size(); ++i) {
DCHECK_GE(multi_index[i], 0);
@ -94,8 +92,6 @@ namespace xla {
/* static */ std::vector<int64> IndexUtil::LinearIndexToMultidimensionalIndex(
const Shape& shape, int64 linear_index) {
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
DCHECK_GE(linear_index, 0);
DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape));
@ -133,18 +129,12 @@ namespace xla {
/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape,
int64 dimension) {
int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size();
int64 stride = 1;
DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size());
for (auto dim : LayoutUtil::MinorToMajor(shape)) {
if (dim == dimension) {
break;
}
if (pdim_size == 0) {
stride *= shape.dimensions(dim);
} else {
stride *= LayoutUtil::PaddedDimension(shape, dim);
}
stride *= shape.dimensions()[dim];
}
return stride;
}
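
// --- Illustrative sketch, not part of this commit ---
// The stride computation above as a standalone function over plain dimension
// sizes and a minor-to-major order (padded dimensions are gone, so only
// shape.dimensions() matters). For dims {5, 8, 10, 4} with layout
// {3, 2, 1, 0}, the stride of dimension 1 is 4 * 10 == 40, matching the
// example kept in index_util.h below.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t DimensionStride(const std::vector<int64_t>& dims,
                        const std::vector<int64_t>& minor_to_major,
                        int64_t dimension) {
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    if (dim == dimension) break;
    stride *= dims[dim];
  }
  return stride;
}

int main() {
  std::printf("%lld\n",
              (long long)DimensionStride({5, 8, 10, 4}, {3, 2, 1, 0}, 1));
  return 0;  // prints 40.
}
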


@ -61,8 +61,7 @@ class IndexUtil {
static bool BumpIndices(const Shape& shape, absl::Span<int64> indices);
// Calculates the stride size (in number of elements, not byte size) of a
// given logical shape dimension (from 0 to rank-1). If available, padded
// dimensions are used.
// given logical shape dimension (from 0 to rank-1).
// Example:
// GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) ==
// sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10


@ -201,8 +201,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
if (!ShapeUtil::IsArray(shape)) {
if (layout.minor_to_major_size() != 0 ||
layout.padded_dimensions_size() != 0) {
if (layout.minor_to_major_size() != 0) {
return InvalidArgument(
"shape of primitive type %s should not have a non-trivial layout",
PrimitiveType_Name(shape.element_type()));
@ -241,28 +240,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
dimensions_in_layout[dim] = true;
}
if (layout.padded_dimensions_size() > 0) {
if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) {
return InvalidArgument(
"layout has %d padded dimensions, but shape is rank %d",
layout.padded_dimensions_size(), ShapeUtil::Rank(shape));
}
for (int i = 0; i < layout.padded_dimensions_size(); ++i) {
if (layout.padded_dimensions(i) < shape.dimensions(i)) {
return InvalidArgument(
"for dimension %d, dimension padding (%d) is smaller than "
"the dimension size (%d) of the shape",
i, layout.padded_dimensions(i), shape.dimensions(i));
}
}
}
}
if (layout.format() == SPARSE) {
if (!layout.padded_dimensions().empty()) {
return InvalidArgument("Sparse layout has padded dimensions");
}
}
return Status::OK();
@ -303,38 +280,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
layout.minor_to_major().end(), std::greater<int64>());
}
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) {
return false;
}
CHECK(IsDenseArray(shape)) << shape.ShortDebugString();
CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
return true;
}
}
return false;
}
/* static */ absl::Span<const int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
}
/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape,
int64 index) {
CHECK(IsDenseArray(shape));
return shape.layout().padded_dimensions(index);
}
/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) {
CHECK(IsDenseArray(shape));
return shape.layout().padding_value();
}
/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && shape.has_layout() &&
IsSparse(shape.layout());
@ -513,13 +458,6 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) {
for (int64 minor_to_major : layout.minor_to_major()) {
hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
}
for (int64 padded_dim : layout.padded_dimensions()) {
hash_value = Hash64Combine(hash_value, hash<int64>()(padded_dim));
}
hash_value =
Hash64Combine(hash_value, hash<PaddingValue>()(layout.padding_value()));
hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
return hash_value;


@ -104,23 +104,6 @@ class LayoutUtil {
// more minor, and so on until dimension N-1 which is the minor.
static bool IsMonotonicWithDim0Major(const Layout& layout);
// Returns whether the layout of the given shape has padding (a
// padded_dimension value in Layout is greater than the corresponding
// dimension size).
static bool IsPadded(const Shape& shape);
// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
static absl::Span<const int64> PaddedDimensions(const Shape& shape);
// Returns the given index of the padded_dimensions array for the given Shape.
// Requires that the shape is an array and has a dense layout.
static int64 PaddedDimension(const Shape& shape, int64 index);
// Returns the padding_value for the given Shape. Requires that the shape is
// an array and has a dense layout.
static PaddingValue GetPaddingValue(const Shape& shape);
// Returns whether the given Shape is an array (i.e. not a tuple) and has a
// sparse format layout.
static bool IsSparseArray(const Shape& shape);


@ -304,30 +304,6 @@ TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
shape.tuple_shapes(1).layout()));
}
TEST_F(LayoutUtilTest, IsPadded) {
Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
LayoutUtil::ClearLayout(&shape_without_layout);
EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout));
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
LayoutUtil::SetToDefaultLayout(&shape_with_layout);
EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout));
// Add padding equal to the dimension sizes. In this case the padding is a
// nop.
Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2);
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3);
shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4);
EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding));
Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
shape_with_padding.mutable_layout()->add_padded_dimensions(2);
shape_with_padding.mutable_layout()->add_padded_dimensions(14);
shape_with_padding.mutable_layout()->add_padded_dimensions(42);
EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding));
}
TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
LayoutUtil::GetDefaultLayoutForR2()));


@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
PrimitiveType element_type = subshape.element_type();
if (element_type == PRED) {
// We display predicates in a densely packed form.
return literal.Get<bool>(indices, shape_index) ? "1" : "0";
}
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
literal.GetAsString(indices, shape_index);
// We display predicates as 0s and 1s so that the string is more dense.
string elem = element_type == PRED
? literal.Get<bool>(indices, shape_index) ? "1" : "0"
: literal.GetAsString(indices, shape_index);
return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem;
};
if (ShapeUtil::Rank(subshape) == 0) {


@ -34,16 +34,22 @@ namespace xla {
namespace literal_comparison {
namespace {
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
return val;
}
uint16 GetRawValue(Eigen::half val) { return val.x; }
// Helper function for comparing a floating point type, FloatT, bitwise equal
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
absl::Span<const int64> multi_index) {
// TODO(b/118627822): These are unsafe bit_casts because Eigen::Half is not
// trivially copyable.
auto ulhs = absl::bit_cast<UnsignedT>(lhs);
auto urhs = absl::bit_cast<UnsignedT>(rhs);
auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
auto lhs_double = static_cast<double>(lhs);
auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) {

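// --- Illustrative sketch, not part of this commit ---
// The bitwise-equality idea behind CompareFloatsBitwiseEqual in plain C++,
// using memcpy instead of absl::bit_cast (and skipping the Eigen::half
// special case handled by GetRawValue above): two floats compare equal only
// if their raw bit patterns match, which distinguishes -0.0f from 0.0f and
// different NaN payloads.
#include <cstdint>
#include <cstdio>
#include <cstring>

bool BitwiseEqual(float lhs, float rhs) {
  uint32_t ulhs, urhs;
  std::memcpy(&ulhs, &lhs, sizeof(ulhs));
  std::memcpy(&urhs, &rhs, sizeof(urhs));
  return ulhs == urhs;
}

int main() {
  std::printf("%d %d\n", BitwiseEqual(1.0f, 1.0f), BitwiseEqual(0.0f, -0.0f));
  return 0;  // prints "1 0".
}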

@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
EXPECT_EQ("{101}", pred_vec.ToString());
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
}
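
// --- Illustrative sketch, not part of this commit ---
// A standalone rendering of the new PRED formatting exercised by the test
// above (hypothetical helper, not the Literal API): boolean elements are
// printed as comma-separated 0/1 tokens, so {true, false, true} becomes
// "{1, 0, 1}" instead of the old packed "{101}".
#include <cstdio>
#include <string>
#include <vector>

std::string PredVectorToString(const std::vector<bool>& values) {
  std::string out = "{";
  for (size_t i = 0; i < values.size(); ++i) {
    if (i > 0) out += ", ";
    out += values[i] ? "1" : "0";
  }
  return out + "}";
}

int main() {
  std::printf("%s\n", PredVectorToString({true, false, true}).c_str());
  return 0;  // prints {1, 0, 1}.
}
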
TEST_F(LiteralUtilTest, R2ToString) {


@ -184,8 +184,9 @@ StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
int64 LocalShapedBufferTuple::size() const { return elements_.size(); }
XrtAllocation::XrtAllocation(int64 handle, Shape shape)
: handle_(handle), shape_(shape) {}
XrtAllocation::XrtAllocation(int64 handle, Shape shape,
const string& session_target)
: handle_(handle), shape_(shape), session_target_(session_target) {}
XrtAllocation::~XrtAllocation() {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
@ -198,7 +199,7 @@ XrtAllocation::~XrtAllocation() {
return;
}
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target_);
tensorflow::ClientSession::FeedType inputs;
inputs.insert({allocation_handle, handle()});
std::vector<tensorflow::Tensor> outputs;
@ -210,7 +211,8 @@ XrtAllocation::~XrtAllocation() {
}
/* static */
StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(const Literal& argument) {
StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(
const Literal& argument, const string& session_target) {
xrt::XLAAllocation alloc;
alloc.set_device_ordinal(0);
*alloc.mutable_value() = argument.ToProto();
@ -221,14 +223,14 @@ StatusOr<XrtAllocation*> XrtAllocation::FromLiteral(const Literal& argument) {
auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string);
TF_RETURN_IF_ERROR(root.status());
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target);
tensorflow::ClientSession::FeedType inputs;
inputs.insert({literal_string, alloc.SerializeAsString()});
std::vector<tensorflow::Tensor> outputs;
TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs));
int64 handle = outputs[0].scalar<int64>()();
return new XrtAllocation(handle, argument.shape());
return new XrtAllocation(handle, argument.shape(), session_target);
}
const int64 XrtAllocation::handle() const { return handle_; }
@ -242,7 +244,7 @@ StatusOr<Literal> XrtAllocation::ToLiteral() const {
auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle);
TF_RETURN_IF_ERROR(root.status());
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target_);
tensorflow::ClientSession::FeedType inputs;
inputs.insert({allocation_handle, handle()});
std::vector<tensorflow::Tensor> outputs;
@ -357,8 +359,11 @@ static StatusOr<Shape> GetReturnValueShape(const XlaComputation& computation) {
}
CompiledXrtComputation::CompiledXrtComputation(
const ProgramShape& program_shape, int64 handle)
: program_shape_(program_shape), handle_(handle) {}
const ProgramShape& program_shape, int64 handle,
const string& session_target)
: program_shape_(program_shape),
handle_(handle),
session_target_(session_target) {}
CompiledXrtComputation::~CompiledXrtComputation() {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
@ -371,7 +376,7 @@ CompiledXrtComputation::~CompiledXrtComputation() {
return;
}
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target_);
tensorflow::ClientSession::FeedType inputs;
inputs.insert({computation_handle, handle()});
std::vector<tensorflow::Tensor> outputs;
@ -407,7 +412,7 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
e.set_release_input_handles(false);
e.set_release_compilation_handle(false);
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target_);
tensorflow::ClientSession::FeedType inputs;
for (int i = 0; i < arguments.size(); ++i) {
inputs.insert({arguments[i], argument_handles[i]->handle()});
@ -418,7 +423,7 @@ StatusOr<XrtAllocation*> CompiledXrtComputation::Execute(
TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs));
int64 output = outputs[0].scalar<int64>()();
return new XrtAllocation(output, program_shape().result());
return new XrtAllocation(output, program_shape().result(), session_target_);
}
const ProgramShape& CompiledXrtComputation::program_shape() const {
@ -451,7 +456,7 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
}
StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
const std::vector<Shape>& argument_shapes) {
const std::vector<Shape>& argument_shapes, const string& session_target) {
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING);
auto compile = tensorflow::ops::XRTCompile(root, program);
@ -468,7 +473,7 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
auto snapshot = computation().Snapshot().ValueOrDie();
*c.mutable_hlo_snapshot() = *snapshot;
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target);
tensorflow::ClientSession::FeedType inputs;
inputs.insert({program, c.SerializeAsString()});
std::vector<tensorflow::Tensor> outputs;
@ -477,7 +482,7 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation().GetProgramShape());
int64 handle = outputs[0].scalar<int64>()();
return new CompiledXrtComputation(program_shape, handle);
return new CompiledXrtComputation(program_shape, handle, session_target);
}
const XlaComputation& LocalComputation::computation() const {
@ -929,7 +934,7 @@ StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
}
StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
XrtAllocation* allocation) {
XrtAllocation* allocation, const string& session_target) {
const Shape& tuple_shape = allocation->shape();
if (!ShapeUtil::IsTuple(tuple_shape)) {
@ -945,7 +950,7 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index);
TF_RETURN_IF_ERROR(root.status());
tensorflow::ClientSession session(root, "local");
tensorflow::ClientSession session(root, session_target);
tensorflow::ClientSession::FeedType inputs;
std::vector<XrtAllocation*> results;
for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
@ -964,7 +969,8 @@ StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
const int64 subtuple_handle = outputs[0].scalar<int64>()();
const Shape& subtuple_shape =
ShapeUtil::GetTupleElementShape(tuple_shape, i);
results.push_back(new XrtAllocation(subtuple_handle, subtuple_shape));
results.push_back(
new XrtAllocation(subtuple_handle, subtuple_shape, session_target));
}
return new XrtAllocationTuple(std::move(results));
}
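
// --- Illustrative sketch, not part of this commit ---
// The pattern threaded through XrtAllocation/CompiledXrtComputation above,
// reduced to a toy RAII wrapper: the object remembers the session target it
// was created with, so teardown (and any later graph runs) talk to the same
// target instead of a hard-coded "local" session.
#include <cstdio>
#include <string>
#include <utility>

class ToyXrtHandle {
 public:
  ToyXrtHandle(long handle, std::string session_target)
      : handle_(handle), session_target_(std::move(session_target)) {}
  ~ToyXrtHandle() {
    // Release against the same target the handle was allocated on.
    std::printf("releasing %ld on %s\n", handle_, session_target_.c_str());
  }

 private:
  const long handle_;
  const std::string session_target_;
};

int main() {
  ToyXrtHandle h(42, "grpc://tpu-worker:8470");  // hypothetical target.
  return 0;
}
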


@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@ -110,17 +113,22 @@ StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
// graph, and an XLA shape to track the referent's shape.
class XrtAllocation {
public:
static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument);
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which allocation and deallocation
// graphs are run.
static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument,
const string& session_target);
XrtAllocation(int64 handle, Shape shape);
XrtAllocation(int64 handle, Shape shape, const string& session_target);
~XrtAllocation();
StatusOr<Literal> ToLiteral() const;
const Shape& shape() const;
const int64 handle() const;
private:
int64 handle_;
Shape shape_;
const int64 handle_;
const Shape shape_;
const string session_target_;
};
// Result of a tuple destructuring operation on an XrtAllocation.
@ -145,8 +153,12 @@ class XrtAllocationTuple {
// Destructures a tuple-valued XrtAllocation into its constituent elements
// in XrtAllocationTuple form.
//
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the sub-tupling graph is run,
// and passed along in constructing each constituent XrtAllocation.
StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
XrtAllocation* allocation);
XrtAllocation* allocation, const string& session_target);
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Specifically, wraps an XLA LocalExecutable.
@ -165,7 +177,10 @@ class CompiledLocalComputation {
// device-allocated literals. Specifically, wraps an XRT computation handle.
class CompiledXrtComputation {
public:
CompiledXrtComputation(const ProgramShape& program_shape, int64 handle);
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the execution graph is run.
CompiledXrtComputation(const ProgramShape& program_shape, int64 handle,
const string& session_target);
~CompiledXrtComputation();
StatusOr<XrtAllocation*> Execute(
@ -175,8 +190,9 @@ class CompiledXrtComputation {
int64 handle() const;
private:
ProgramShape program_shape_;
int64 handle_;
const ProgramShape program_shape_;
const int64 handle_;
const string session_target_;
};
// Wraps a XlaComputation produced by a LocalComputationBuilder. The
@ -191,8 +207,10 @@ class LocalComputation {
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options);
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the compilation graph is run.
StatusOr<CompiledXrtComputation*> CompileForXrt(
const std::vector<Shape>& argument_shapes);
const std::vector<Shape>& argument_shapes, const string& session_target);
const XlaComputation& computation() const;


@ -451,6 +451,10 @@ tensorflow::ImportNumpy();
// Shape
%typemap(out) const Shape& {
$result = numpy::PyShapeInfoFromXlaShape(*$1);
}
%typemap(out) StatusOr<Shape> {
if ($1.ok()) {
$result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
@ -980,6 +984,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral;
%unignore xla::swig::LocalShapedBuffer::shape;
%unignore xla::swig::LocalShapedBufferTuple;
%unignore xla::swig::LocalShapedBufferTuple::Release;
%unignore xla::swig::LocalShapedBufferTuple::size;


@ -51,6 +51,10 @@ class BackendType(enum.Enum):
XRT = 2
BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target'))
XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local')
def OpMetadataToProto(pyobj):
proto = xla_data_pb2.OpMetadata()
for field in _OP_METADATA_FIELDS:
@ -211,17 +215,17 @@ class LocalBuffer(object):
def __init__(self, c_buffer, backend):
self.c_buffer = c_buffer
self._backend = backend
if backend == BackendType.XRT:
if backend.backend_type == BackendType.XRT:
self._delete = c_api.DeleteXrtAllocation
else:
self._delete = c_api.DeleteLocalShapedBuffer
@staticmethod
def from_pyval(pyval, backend=BackendType.XLA_LOCAL):
def from_pyval(pyval, backend=XLA_LOCAL_BACKEND):
"""Allocate and copy to XLA the given python value."""
pyval = require_numpy_array_layout(pyval)
if backend == BackendType.XRT:
cbuf = c_api.XrtAllocation.FromLiteral(pyval)
if backend.backend_type == BackendType.XRT:
cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target)
else:
cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None)
return LocalBuffer(cbuf, backend)
@ -229,6 +233,9 @@ class LocalBuffer(object):
def to_py(self):
return self.c_buffer.ToLiteral()
def shape(self):
return _wrap_shape(self.c_buffer.shape())
def delete(self):
if self.c_buffer is not None:
self._delete(self.c_buffer)
@ -237,8 +244,9 @@ class LocalBuffer(object):
def destructure(self):
"""Assuming a tuple buffer, unpack it into constituent tuple elements."""
assert self.c_buffer is not None
if self._backend == BackendType.XRT:
result = c_api.DestructureXrtAllocationTuple(self.c_buffer)
if self._backend.backend_type == BackendType.XRT:
result = c_api.DestructureXrtAllocationTuple(self.c_buffer,
self._backend.target)
else:
result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer)
self.delete()
@ -467,14 +475,14 @@ class LocalComputation(object):
ComputationBuilder methods.
"""
def __init__(self, c_computation, is_compiled, backend=BackendType.XLA_LOCAL):
def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND):
self._c_computation = c_computation
self._backend = backend
self._is_compiled = is_compiled
# Ensure a reference to C-based destructor for use in __del__.
if is_compiled:
if backend == BackendType.XRT:
if backend.backend_type == BackendType.XRT:
assert isinstance(c_computation, c_api.CompiledXrtComputation)
self._delete = c_api.DeleteCompiledXrtComputation
else:
@ -535,8 +543,8 @@ class LocalComputation(object):
compile_options = compile_options or CompileOptions()
compile_options.result_shape = result_shape
if self._backend == BackendType.XRT:
c = self.computation.CompileForXrt(argument_shapes)
if self._backend.backend_type == BackendType.XRT:
c = self.computation.CompileForXrt(argument_shapes, self._backend.target)
else:
c = self.computation.Compile(argument_shapes, compile_options)
return LocalComputation(c, is_compiled=True, backend=self._backend)
@ -590,7 +598,7 @@ class ComputationBuilder(object):
self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
self._parameter_numbering = itertools.count()
def Build(self, root=None, backend=BackendType.XLA_LOCAL):
def Build(self, root=None, backend=XLA_LOCAL_BACKEND):
if root is not None:
return LocalComputation(
self._client.BuildWithRoot(root), is_compiled=False, backend=backend)


@ -439,6 +439,13 @@ class LocalBufferTest(LocalComputationTest):
np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
def testShape(self):
pyval = np.array([[1., 2.]], np.float32)
local_buffer = xla_client.LocalBuffer.from_pyval(pyval)
xla_shape = local_buffer.shape()
self.assertEqual(xla_shape.dimensions(), (1, 2,))
self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))
class SingleOpTest(LocalComputationTest):
"""Tests for single ops.


@ -323,7 +323,6 @@ cc_library(
":hlo_casting_utils",
":hlo_module_config",
":hlo_proto",
":hlo_reachability",
":name_uniquer",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal",
@ -402,6 +401,7 @@ cc_library(
srcs = ["hlo_reachability.cc"],
hdrs = ["hlo_reachability.h"],
deps = [
":hlo",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
@ -1103,6 +1103,7 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
":hlo_reachability",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -1362,6 +1363,7 @@ cc_library(
":fusion_queue",
":hlo",
":hlo_pass",
":hlo_reachability",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
@ -1387,6 +1389,7 @@ cc_library(
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
":hlo_reachability",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo",
@ -3241,6 +3244,7 @@ cc_library(
":hlo_profile_printer_data",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
"@com_google_absl//absl/strings",
],
)
@ -3365,6 +3369,7 @@ cc_library(
":bfloat16_normalization",
":defuser",
":hlo",
":hlo_memory_scheduler",
":hlo_pass",
":hlo_pass_pipeline",
":implicit_broadcast_remover",
@ -3448,6 +3453,7 @@ tf_cc_test(
":hlo_casting_utils",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",


@ -306,6 +306,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Tries to use a kDot in place of the given convolution.
StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
// Tries to simplify a slice(pad(...)) where the result of the slice is a
// scalar.
StatusOr<bool> TrySimplifySliceOfPad(HloInstruction* slice);
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@ -1822,6 +1826,62 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
return Status::OK();
}
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifySliceOfPad(
HloInstruction* slice) {
// Only try to do this for effective scalars. We could do the same for slicing
// out larger pieces of padding (replacing with a broadcast of the padding
// value), but this is probably not worth it.
if (!ShapeUtil::IsEffectiveScalar(slice->shape()) ||
slice->operand(0)->opcode() != HloOpcode::kPad) {
return false;
}
VLOG(10) << "Trying to simplify scalar slice of pad";
// Check there's no internal padding. Again, we could handle that too, since
// everything is statically known, but it's not worth it.
auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
auto padding_config = pad->padding_config();
int64 rank = padding_config.dimensions_size();
if (HasInteriorPadding(padding_config)) {
VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
return false;
}
// Check whether the scalar we're slicing out falls into the padding.
bool in_padding = [&]() {
for (int64 i = 0; i < rank; ++i) {
int64 start = slice->slice_starts(i);
int64 low = padding_config.dimensions(i).edge_padding_low();
int64 data = pad->operand(0)->shape().dimensions(i);
if (start >= low && start < low + data) {
return false;
}
}
return true;
}();
if (in_padding) {
VLOG(10) << "Folding scalar slice of pad into padding value";
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
slice, HloInstruction::CreateReshape(slice->shape(),
pad->mutable_padding_value())));
return true;
} else {
// We already know the output of the slice is scalar. If the operand being
// padded is also scalar, and the slice does not land in the padding, then
// the slice is exactly that operand's value.
bool replaced =
ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
if (replaced) {
VLOG(10) << "Folding scalar slice of pad into padded value";
} else {
VLOG(10) << "Not folding scalar slice of pad into padded value as they "
"have different shapes.";
}
return replaced;
}
}
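
// --- Illustrative sketch, not part of this commit ---
// The geometric test behind TrySimplifySliceOfPad, as a standalone helper
// over plain integers (hypothetical types; assumes no interior padding,
// which the pass checks first, and the pass itself only folds the clear-cut
// cases): a scalar slice of pad(x, value) reads the padding value exactly
// when at least one coordinate of the slice start falls outside the span
// [edge_padding_low, edge_padding_low + dim_size(x)) occupied by x.
#include <cstdint>
#include <cstdio>
#include <vector>

struct PadDim {
  int64_t edge_padding_low;
  int64_t operand_size;  // size of the un-padded operand in this dimension.
};

bool ScalarSliceReadsPadding(const std::vector<int64_t>& slice_starts,
                             const std::vector<PadDim>& dims) {
  for (size_t i = 0; i < dims.size(); ++i) {
    int64_t start = slice_starts[i];
    int64_t low = dims[i].edge_padding_low;
    if (start < low || start >= low + dims[i].operand_size) {
      return true;  // this coordinate lands in edge padding.
    }
  }
  return false;  // every coordinate lands inside the original operand.
}

int main() {
  // f32[3,4] padded with padding=3_2x1_5: slice start (2, 0) lands in the
  // padding, slice start (4, 2) lands in the data (compare the SliceOfPad*
  // tests below).
  std::vector<PadDim> dims = {{3, 3}, {1, 4}};
  std::printf("%d %d\n", ScalarSliceReadsPadding({2, 0}, dims),
              ScalarSliceReadsPadding({4, 2}, dims));
  return 0;  // prints "1 0".
}
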
Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
// Delete no-op slices, i.e. where shape = operand shape.
if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
@ -1846,6 +1906,12 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
slice->shape(), operand_slice->mutable_operand(0),
new_slice_starts, new_slice_limits, slice->slice_strides()));
}
TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice));
if (replaced) {
return Status::OK();
}
return Status::OK();
}


@ -3163,6 +3163,92 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Reshape(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Reshape(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[1,1] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Parameter());
}
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;


@ -151,15 +151,10 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
// Do not fold BF16 conversions for instructions related to tuples, entry and
// exit of a computation, fusion, convert, and control flow.
// exit of a computation, fusion, convert, side-effecting instructions and
// control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
hlo->opcode() == HloOpcode::kSend || //
hlo->opcode() == HloOpcode::kSendDone || //
hlo->opcode() == HloOpcode::kRecv || //
hlo->opcode() == HloOpcode::kRecvDone || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
@ -167,7 +162,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional) {
hlo->opcode() == HloOpcode::kConditional || //
hlo->HasSideEffectNoRecurse()) {
return Status::OK();
}
if (hlo == computation_->root_instruction() &&
@ -182,6 +178,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
if (crs->IsCrossModuleAllReduce()) {
// Cross-module all-reduce has side effect.
return Status::OK();
}
// First use DefaultAction() to handle the operands. It can't handle
// tuple-shaped output.
TF_RETURN_IF_ERROR(DefaultAction(crs));


@ -346,11 +346,9 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
// Do not change instructions related to entry and exit of a computation,
// tuples, fusion, convert, and control flow.
// tuples, fusion, convert, side-effecting instructions, and control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
@ -358,7 +356,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional) {
hlo->opcode() == HloOpcode::kConditional || //
hlo->HasSideEffectNoRecurse()) {
return Status::OK();
}
// TODO(b/112040122): Correctly normalize variadic reduce.


@ -236,6 +236,10 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
// the end of the BFloat16Propagation pass.
continue;
}
if (use.instruction->HasSideEffectNoRecurse()) {
// Keep side-effecting instruction's operands unchanged.
return false;
}
// Any visited user that can accept BF16 has already been updated if
// necessary, e.g., the output has been changed to BF16 if it propagates
// precision, or a called computation's parameters have been changed to
@ -329,22 +333,6 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
return;
}
// Do not change precision for instructions related to entry and exit of a
// computation, and control flow, because this pass might break the interfaces
// or assumptions for them.
if (hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
hlo->opcode() == HloOpcode::kSend || //
hlo->opcode() == HloOpcode::kSendDone || //
hlo->opcode() == HloOpcode::kRecv || //
hlo->opcode() == HloOpcode::kRecvDone || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kConditional || //
(hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
return;
}
// Prevent root instructions from having their output modified by recording
// all F32 output values as needing to stay as F32.
CHECK(hlo->parent() != nullptr);
@ -366,6 +354,17 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
return;
}
// Do not change precision for instructions related to entry and exit of a
// computation, side-effecting instructions, and control flow, because this
// pass might break the interfaces or assumptions for them.
if (hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kConditional || //
hlo->HasSideEffectNoRecurse() || //
(hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
return;
}
if (!ContainsKey(consider_using_bfloat16_, hlo)) {
return;
}
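
// --- Illustrative sketch, not part of this commit ---
// The shape of the refactor applied to the three bfloat16 passes above, as a
// toy standalone predicate (hypothetical types, not HloInstruction): the
// hard-coded lists of infeed/outfeed/send/recv opcodes are replaced by a
// single "has side effect" query, so newly added side-effecting ops (e.g.
// cross-module all-reduce) are skipped without touching each pass again.
#include <cstdio>
#include <set>
#include <string>

struct ToyInstruction {
  std::string opcode;
  bool has_side_effect;
};

bool SkipPrecisionChange(const ToyInstruction& hlo) {
  static const std::set<std::string> kBoundaryOpcodes = {
      "tuple",   "get-tuple-element", "constant",    "parameter", "fusion",
      "convert", "call",              "custom-call", "while",     "conditional"};
  return hlo.has_side_effect || kBoundaryOpcodes.count(hlo.opcode) > 0;
}

int main() {
  std::printf("%d %d\n", SkipPrecisionChange({"all-reduce", true}),
              SkipPrecisionChange({"add", false}));
  return 0;  // prints "1 0".
}
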


@ -136,6 +136,40 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
EXPECT_FALSE(OutputsBF16(c));
}
// Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
HloInstruction* b =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
auto rb = HloComputation::Builder(TestName());
rb.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd,
rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
auto reduction = module->AddEmbeddedComputation(rb.Build());
HloInstruction* all_reduce =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
/*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), root);
}
// Tests that if a constant is converted to BF16 then its literal must also be
// converted.
TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {


@ -502,8 +502,8 @@ Status CreateHloProfilingArtifacts(
HloCostAnalysis cost_analysis(shape_size_bytes);
TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
*hlo_profile_printer_data =
CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis);
*hlo_profile_printer_data = CreateHloProfilePrinterData(
**hlo_profile_index_map, cost_analysis, entry_computation.name());
*computation_to_profile_idx =
(*hlo_profile_index_map)->computation_to_profile_idx();


@ -1546,10 +1546,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
}
// Return whether the given shape is a matrix with no padding.
static bool IsRank2WithNoPadding(const Shape& shape) {
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
}
// Return whether the given shape is rank 2.
static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; }
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
@ -1565,8 +1563,7 @@ static bool AreValidGemmShapes(
return false;
}
if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape))) {
if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) {
return false;
}


@ -2206,16 +2206,16 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
// Delegate to common implementation of fused in-place dynamic-update-slice.
auto operands = GetIrArraysForOperandsOf(fusion);
return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_);
fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
&elemental_emitter, &b_);
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
VLOG(3) << "HandleFusion kLoop";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
auto operands = GetIrArraysForOperandsOf(fusion);
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
&elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
@ -2415,14 +2415,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
*failure_reason = "operand has mismatching layouts";
return false;
}
if (LayoutUtil::IsPadded(op->shape())) {
*failure_reason = "operand has padded layout";
return false;
}
}
CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
// We split the dimensions into three categories: the dimension over which we
// are concatenating (concat_dim), the dimensions that are minor to it
// (inner_dims) and the dimensions that are major to it (outer_dims).
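
A hedged sketch (not part of the diff) of the fast-concatenate idea the comment above describes: with outer taken as the product of the dimensions major to concat_dim and inner as the product of the dimensions minor to it, each operand contributes one contiguous run per outer index, which is why padded layouts had to be rejected before this change. Row-major f32 and matching outer/inner extents are assumed; the function name is illustrative.

#include <cstddef>
#include <cstring>
#include <vector>

// Concatenate a ([outer, a_concat, inner]) and b ([outer, b_concat, inner])
// along the middle dimension by copying one contiguous slab per outer index.
std::vector<float> ConcatAlongMiddleDim(const float* a, int a_concat,
                                        const float* b, int b_concat,
                                        int outer, int inner) {
  const int out_concat = a_concat + b_concat;
  std::vector<float> out(static_cast<std::size_t>(outer) * out_concat * inner);
  for (int o = 0; o < outer; ++o) {
    float* dst = out.data() + static_cast<std::size_t>(o) * out_concat * inner;
    std::memcpy(dst, a + static_cast<std::size_t>(o) * a_concat * inner,
                sizeof(float) * a_concat * inner);
    std::memcpy(dst + static_cast<std::size_t>(a_concat) * inner,
                b + static_cast<std::size_t>(o) * b_concat * inner,
                sizeof(float) * b_concat * inner);
  }
  return out;
}
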


@ -59,6 +59,9 @@ namespace cpu {
class IrEmitter : public DfsHloVisitorWithDefault,
public IrBuilderMixin<IrEmitter> {
public:
using GeneratorForOperandIrArrays =
std::function<std::vector<llvm_ir::IrArray>()>;
// Create a new LLVM IR emitter.
//
// hlo_module: the HLO module we are emitting IR for.
@ -208,6 +211,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
const HloInstruction* hlo);
GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
HloInstruction* unnested_hlo) {
return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
}
// Augments IrArray with aliasing information.
void AddAliasingInformationToIrArray(const HloInstruction& hlo,
llvm_ir::IrArray* array) {


@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/defuser.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
namespace xla {
@ -45,6 +46,7 @@ class ControlDepRemover : public HloModulePass {
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
pipeline_.AddPass<HloDescheduler>();
pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();


@ -484,7 +484,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
// Extracted from //learning/brain/google/xla/benchmarks/resnet.py
// Extracted from Resnet-50.
//
// For simplicity, we focus on the column dimension and ignore other dimensions.
// We use [?] to represent the shape instead of the content.


@ -126,9 +126,9 @@ Status RunCudnnConvImpl(CudnnConvParams params,
int64 feature_group_count = params.feature_group_count;
AlgorithmConfig algorithm = params.algorithm;
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm()->algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
<< algorithm.algorithm()->tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
@ -302,8 +302,8 @@ Status RunCudnnConvImpl(CudnnConvParams params,
if (!stream->ok()) {
return InternalError(
"Unable to launch convolution with type %s and algorithm (%d, %d)",
CudnnConvKindToString(kind), algorithm.algorithm().algo_id(),
algorithm.algorithm_no_scratch().algo_id());
CudnnConvKindToString(kind), algorithm.algorithm()->algo_id(),
algorithm.algorithm_no_scratch()->algo_id());
}
return Status::OK();
}


@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -51,7 +52,8 @@ struct MatrixDescriptor {
// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
MatrixDescriptor output_matrix, double alpha, double beta,
se::Stream* stream) {
DCHECK(!output_matrix.transpose);
const int64 batch_size = lhs_matrix.batch_size;
@ -73,7 +75,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
.ok();
}
@ -88,7 +90,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
/*alpha=*/alpha, lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
/*beta=*/0.0, &output_data,
/*beta=*/beta, &output_data,
/*leading dim of output=*/output_matrix.num_rows, output_stride,
batch_size)
.ok();
@ -112,6 +114,7 @@ template <typename Element>
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha,
double beta,
se::blas::ComputationType computation_type,
se::blas::AlgorithmType algorithm, se::Stream* stream,
se::blas::ProfileResult* output_profile_result) {
@ -138,7 +141,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
/*alpha=*/static_cast<Element>(alpha), lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows,
/*beta=*/static_cast<Element>(0.0f), &output_data,
/*beta=*/static_cast<Element>(beta), &output_data,
/*leading dim of output=*/output_matrix.num_rows, computation_type,
algorithm, output_profile_result)
.ok();
@ -153,7 +156,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
template <typename Element>
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha,
MatrixDescriptor output_matrix, double alpha, double beta,
se::blas::ComputationType computation_type, se::Stream* stream) {
std::vector<se::blas::AlgorithmType> algorithms;
CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
@ -166,7 +169,7 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
// non-null ProfileResult, DoGemmWithAlgorithm should always return true,
// and the actual success-ness is returned in ProfileResult::is_valid.
CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
alpha, computation_type, algorithm,
alpha, beta, computation_type, algorithm,
stream, &profile_result));
if (profile_result.is_valid()) {
@ -263,8 +266,9 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
}
CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(),
HloOpcode::kMultiply);
CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd ||
hlo_instruction.fused_expression_root()->opcode() ==
HloOpcode::kMultiply);
// Try to find the dot inside the output fusion node.
const HloInstruction* dot =
hlo_instruction.fused_expression_root()->operand(0);
@ -282,8 +286,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape, double alpha,
const HloInstruction* hlo_instruction)
const Shape& output_shape, double alpha, double beta,
const HloInstruction* hlo_instruction,
bool implements_whole_instruction)
: Thunk(Kind::kGemm, hlo_instruction),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
@ -291,7 +296,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
lhs_shape_(lhs_shape),
rhs_shape_(rhs_shape),
output_shape_(output_shape),
alpha_(alpha) {}
alpha_(alpha),
beta_(beta),
implements_whole_instruction_(implements_whole_instruction) {}
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
@ -386,7 +393,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// TODO(b/112111608): Implement auto tune for batched gemm.
if (batch_size != 1) {
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
alpha_, stream);
alpha_, beta_, stream);
}
auto thunk_name = [&] {
@ -398,9 +405,27 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
auto autotune_it = autotune_results_.find(device_name);
if (autotune_it == autotune_results_.end()) {
VLOG(3) << "Starting autotune of GemmThunk " << thunk_name();
StatusOr<se::blas::AlgorithmType> best_algorithm =
GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
alpha_, computation_type, stream);
// If the output buffer already contains a bias then autotune into a
// scratch buffer. This avoids overwriting the bias buffer. The scratch
// buffer may contain arbitrary garbage values.
se::DeviceMemoryBase scratch_data = output_data;
std::unique_ptr<se::TemporaryDeviceMemory<char>> scratch_mem;
if (beta_ != 0.0) {
auto temp_status = stream->AllocateTemporaryArray<char>(
ShapeUtil::ByteSizeOf(output_shape_));
if (!temp_status.ok()) {
return false;
}
scratch_mem = std::move(temp_status).ValueOrDie();
scratch_data = scratch_mem->device_memory();
}
const MatrixDescriptor scratch_descriptor(
scratch_data, false, output_num_cols, output_num_rows, batch_size);
StatusOr<se::blas::AlgorithmType> best_algorithm = GetGemmAutotuneFn(
element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_,
beta_, computation_type, stream);
autotune_it =
autotune_results_.insert({device_name, best_algorithm}).first;
@ -421,18 +446,19 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(2) << "Using algorithm " << algorithm
<< " chosen by autotuning on GemmThunk " << thunk_name();
return GetGemmWithAlgorithmFn(element_type)(
lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type,
algorithm, stream,
lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_,
computation_type, algorithm, stream,
/*output_profile_result=*/nullptr);
}
// Autotune will fail when CUDA 8 and GPU sm_50 or older are used.
// Use the older Gemm API in this case.
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
alpha_, stream);
alpha_, beta_, stream);
};
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
auto op_profiler = profiler->MakeScopedInstructionProfiler(
implements_whole_instruction_ ? hlo_instruction() : nullptr);
bool launch_ok;
if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
launch_ok = launch(lhs_descriptor, rhs_descriptor,

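For reference, a hedged sketch (not part of the diff) of the update that the new beta parameter threads through these GEMM calls: output = alpha * (lhs x rhs) + beta * output. When beta is non-zero, the output buffer's existing contents (for example a fused bias) are scaled and accumulated rather than overwritten, which is also why the autotuning path above copies into a scratch buffer before running trial algorithms. The function below is illustrative only, assumes row-major f32 matrices, and is not the code the thunk executes.

// Illustrative reference GEMM epilogue; lhs is m x k, rhs is k x n,
// output is m x n, all row-major.
void ReferenceGemm(const float* lhs, const float* rhs, float* output,
                   int m, int k, int n, double alpha, double beta) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) {
        acc += static_cast<double>(lhs[i * k + p]) * rhs[p * n + j];
      }
      // beta == 0 reproduces the previous overwrite behavior; beta == 1 adds
      // the pre-existing contents of output (e.g. a bias) to the product.
      output[i * n + j] =
          static_cast<float>(alpha * acc + beta * output[i * n + j]);
    }
  }
}
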

@ -41,8 +41,9 @@ class GemmThunk : public Thunk {
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape, double alpha,
const HloInstruction* hlo_instruction);
const Shape& output_shape, double alpha, double beta,
const HloInstruction* hlo_instruction,
bool implements_whole_instruction);
GemmThunk(const GemmThunk&) = delete;
GemmThunk& operator=(const GemmThunk&) = delete;
@ -70,6 +71,9 @@ class GemmThunk : public Thunk {
const Shape output_shape_;
const double alpha_;
const double beta_;
const bool implements_whole_instruction_;
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
// results. The map's value is the best algorithm we've found for this thunk


@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering(
for (auto* computation : module->computations()) {
if (computation != module->entry_computation() &&
!computation->IsFusionComputation()) {
predecessors_.emplace(computation, computation->ComputeReachability());
predecessors_.emplace(computation,
HloReachabilityMap::Build(computation));
}
}
}


@ -179,6 +179,10 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
return true;
}
} else if (consumer->operand_count() == 2 &&
consumer->opcode() == HloOpcode::kAdd) {
// Fuse a bias add into the output of the dot.
return true;
}
}


@ -331,6 +331,33 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
op::Broadcast(op::Constant())));
}
TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
alpha = f32[] constant(3)
broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
p2 = f32[4,4]{1,0} parameter(2)
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT add = f32[4,4] add(dot, p2)
})")
.ValueOrDie();
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
EXPECT_THAT(root->fused_expression_root(),
op::Add(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
op::Parameter()));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
// duplicated and fused into both reduces.
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {


@ -38,10 +38,9 @@ namespace gpu {
namespace {
// Return whether the given shape is a matrix with no padding.
bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
!LayoutUtil::IsPadded(shape);
// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2;
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
@ -56,10 +55,9 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
output_primitive_type == F64 || output_primitive_type == C64);
return type_is_allowed &&
IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
IsRank2(rhs_shape, batch_dimensions_size) &&
IsRank2(output_shape, batch_dimensions_size) &&
!ShapeUtil::IsZeroElementArray(lhs_shape) &&
!ShapeUtil::IsZeroElementArray(rhs_shape);
}
@ -93,7 +91,8 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kFusion &&
hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) {
(hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply ||
hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) {
// Try to find the dot inside the output fusion node.
const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
if (dot->opcode() != HloOpcode::kDot) {


@ -697,15 +697,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
// kFusion for library calls should be handled by
// IrEmitterUnnested::HandleFusion.
CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind());
std::vector<llvm_ir::IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
&elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());

Some files were not shown because too many files have changed in this diff.