From 27412f3b64ad09131ce330a0b91938af1931d515 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 31 Oct 2017 20:49:18 -0700
Subject: [PATCH] Add compiler/tf2xla/sharding_util.h with utilities for
 getting the core device from a Node.

PiperOrigin-RevId: 174133602
---
 tensorflow/compiler/tf2xla/BUILD              | 31 ++++++++
 .../compiler/tf2xla/kernels/retval_op.cc      |  8 ++-
 tensorflow/compiler/tf2xla/sharding_util.cc   | 72 +++++++++++++++++++
 tensorflow/compiler/tf2xla/sharding_util.h    | 44 ++++++++++++
 .../compiler/tf2xla/sharding_util_test.cc     | 58 +++++++++++++++
 .../compiler/tf2xla/xla_compilation_device.cc | 23 +++---
 tensorflow/compiler/tf2xla/xla_compiler.cc    | 67 ++++++++++++++---
 tensorflow/compiler/tf2xla/xla_compiler.h     |  3 +-
 .../compiler/xla/client/computation_builder.h | 39 ++++++++--
 tensorflow/contrib/tpu/python/tpu/tpu.py      |  7 +-
 10 files changed, 321 insertions(+), 31 deletions(-)
 create mode 100644 tensorflow/compiler/tf2xla/sharding_util.cc
 create mode 100644 tensorflow/compiler/tf2xla/sharding_util.h
 create mode 100644 tensorflow/compiler/tf2xla/sharding_util_test.cc

diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 3c94bcafc1d..d4c6cb56b06 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -123,6 +123,7 @@ cc_library(
         ":const_analysis",
         ":dump_graph",
         ":functionalize_control_flow",
+        ":sharding_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
@@ -169,6 +170,36 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "sharding_util",
+    srcs = ["sharding_util.cc"],
+    hdrs = ["sharding_util.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+tf_cc_test(
+    name = "sharding_util_test",
+    srcs = ["sharding_util_test.cc"],
+    deps = [
+        ":sharding_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 # Internal targets below this point.
 
 cc_library(
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 462267d1504..c283e3b02c2 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -60,7 +60,13 @@ class RetvalOp : public XlaOpKernel {
       OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
       OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
     } else {
-      tc.AddRetval(index_, dtype_, input);
+      // The core from which a return value is returned depends on the core
+      // assignment of the input to the retval. Since we can't change the
+      // core assignment at this point, create a tuple/get-tuple-element
+      // combination so that the core will be set on them.
+      auto tuple_elem =
+          ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
+      tc.AddRetval(index_, dtype_, tuple_elem);
     }
   }
 }
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
new file mode 100644
index 00000000000..d9c839b6101
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE";
+
+static Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
+  return errors::InvalidArgument(
+      "Invalid replicated core id: ", core,
+      "; num_cores_per_replica=", num_cores_per_replica);
+}
+
+xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
+ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) {
+  if (device_name.empty()) {
+    return tensorflow::gtl::optional<xla::OpSharding>();
+  }
+
+  DeviceNameUtils::ParsedName parsed_device;
+  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
+    return errors::InvalidArgument("Malformed assigned device '", device_name,
+                                   "'");
+  }
+  if (!parsed_device.has_type ||
+      !StringPiece(parsed_device.type)
+           .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) {
+    return tensorflow::gtl::optional<xla::OpSharding>();
+  } else {
+    const int core = parsed_device.id;
+    if (core < 0 || core >= num_cores_per_replica) {
+      return CoreOutOfRangeError(core, num_cores_per_replica);
+    }
+    return tensorflow::gtl::optional<xla::OpSharding>(
+        xla::ShardingBuilder::AssignDevice(core));
+  }
+}
+
+xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
+ParseShardingFromDevice(const Node& node, int num_cores_per_replica) {
+  string device_name = node.assigned_device_name();
+  if (device_name.empty()) {
+    device_name = node.requested_device();
+  }
+  return ParseShardingFromDevice(device_name, num_cores_per_replica);
+}
+
+void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
+  string device_name = src.assigned_device_name();
+  if (device_name.empty()) {
+    device_name = src.requested_device();
+  }
+  dst->set_assigned_device_name(device_name);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
new file mode 100644
index 00000000000..f6468bba9f9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Parses the op sharding from the 'replicated core' device_name.
+// Returns an error:
+// - if the device name is invalid.
+// - if the core is parsed and is out of the range [0, num_cores_per_replica).
+//
+// Otherwise, returns either an empty optional, or a sharding set as per
+// xla::ShardingBuilder::AssignDevice.
+xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
+ParseShardingFromDevice(const string& device_name, int num_cores_per_replica);
+
+xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
+ParseShardingFromDevice(const Node& node, int num_cores_per_replica);
+
+void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc
new file mode 100644
index 00000000000..bff5978237a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(CoreUtilTest, ParseShardingFromDevice) {
+  Graph graph(OpRegistry::Global());
+
+  auto core_from_sharding =
+      [](tensorflow::gtl::optional<xla::OpSharding> sharding) -> int64 {
+    if (sharding.has_value() &&
+        sharding.value().type() ==
+            xla::OpSharding::Type::OpSharding_Type_MAXIMAL) {
+      return sharding.value().tile_assignment_devices(0);
+    } else {
+      return -1;
+    }
+  };
+
+  auto parse_status = ParseShardingFromDevice("", 1);
+  TF_EXPECT_OK(parse_status.status());
+  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
+  parse_status = ParseShardingFromDevice("", 100);
+  TF_EXPECT_OK(parse_status.status());
+  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
+
+  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:-1", 100);
+  EXPECT_FALSE(parse_status.ok());
+
+  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:55", 100);
+  TF_EXPECT_OK(parse_status.status());
+  EXPECT_EQ(55, core_from_sharding(parse_status.ValueOrDie()));
+
+  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:100", 100);
+  EXPECT_FALSE(parse_status.ok());
+
+  parse_status = ParseShardingFromDevice("/cpu:0", 100);
+  TF_EXPECT_OK(parse_status.status());
+  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index fc866a4c0a3..7478feb409a 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/core/common_runtime/local_device.h"
@@ -97,23 +98,19 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
   metadata.set_op_name(op_kernel->name());
   b->SetOpMetadata(metadata);
 
-  DeviceNameUtils::ParsedName parsed;
-  OP_REQUIRES(
-      context,
-      DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
-      errors::Internal("Unable to parse device name: ",
-                       op_kernel->requested_device()));
-  // If no device ID assignment is found, XLA is free to use whatever device it
-  // wants. In practice this usually has the effect of placing things on
-  // device 0.
-  if (parsed.has_id) {
-    b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
-  }
+  auto sharding_parse_result = ParseShardingFromDevice(
+      op_kernel->requested_device(), std::numeric_limits<int>::max());
+  OP_REQUIRES_OK(context, sharding_parse_result.status());
+  tensorflow::gtl::optional<xla::OpSharding> op_sharding =
+      sharding_parse_result.ValueOrDie();
+
+  // If no sharding metadata is found, XLA is free to use whatever device it
+  // wants. In practice this usually has the effect of placing things on
+  // device 0.
+  xla::ScopedShardingAssignment assign_sharding(b, op_sharding);
   op_kernel->Compute(context);
 
   b->ClearOpMetadata();
-  b->ClearSharding();
   VLOG(4) << "Done";
 }
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index e49663b8b04..a215254d2e0 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -160,10 +161,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   return graph;
 }
 
-Status XlaCompiler::CompileFunction(
-    const XlaCompiler::CompileOptions& options, const NameAttrList& function,
-    const std::vector<XlaCompiler::Argument>& args,
-    XlaCompiler::CompilationResult* result) {
+Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
+                                    const NameAttrList& function,
+                                    std::vector<XlaCompiler::Argument> args,
+                                    XlaCompiler::CompilationResult* result) {
   const string function_id =
       Canonicalize(function.name(), AttrSlice(&function.attr()));
   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
@@ -241,13 +242,15 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
 
 // Builds XLA computations for each of the arguments to the computation.
 // `args` are the arguments to the computation.
-Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
+Status BuildArguments(const Graph& graph,
+                      const std::vector<XlaCompiler::Argument>& args,
                       bool use_tuple_arg, xla::ComputationBuilder* builder,
-                      XlaContext* context,
+                      XlaContext* context, std::vector<int>* arg_cores,
                       std::vector<XlaExpression>* arg_expressions,
                       std::vector<int>* input_mapping,
                       std::vector<xla::Shape>* input_shapes) {
   arg_expressions->resize(args.size());
+  *arg_cores = std::vector<int>(args.size(), -1);
 
   // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
@@ -302,6 +305,27 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
     (*input_mapping)[i] = parameters[i];
   }
 
+  // Use the _Arg nodes in the graph to resolve core assignments.
+  for (const Node* n : graph.nodes()) {
+    if (StringPiece(n->type_string()) != "_Arg") continue;
+    int index;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+    TF_RET_CHECK(index >= 0 && index < args.size())
+        << "_Arg out of bounds: " << index << " vs " << args.size();
+    TF_ASSIGN_OR_RETURN(
+        auto sharding,
+        ParseShardingFromDevice(*n, std::numeric_limits<int>::max()));
+    if (sharding.has_value()) {
+      TF_RET_CHECK(sharding.value().type() ==
+                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
+      const int core = sharding.value().tile_assignment_devices(0);
+      if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
+        (*arg_cores)[index] = core;
+      }
+      break;
+    }
+  }
+
   // Build parameter handles for non-constant arguments.
   std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
   if (use_tuple_arg) {
@@ -309,10 +333,18 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
     xla::ComputationDataHandle tuple =
         builder->Parameter(0, tuple_shape, "arg_tuple");
     for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
+      const int core = (*arg_cores)[parameters[i]];
+      xla::ScopedShardingAssignment assign_sharding(
+          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+                              : xla::ShardingBuilder::AssignDevice(core));
       arg_handles[i] = builder->GetTupleElement(tuple, i);
     }
   } else {
     for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
+      const int core = (*arg_cores)[parameters[i]];
+      xla::ScopedShardingAssignment assign_sharding(
+          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+                              : xla::ShardingBuilder::AssignDevice(core));
       arg_handles[i] = builder->Parameter(i, (*input_shapes)[i],
                                           strings::StrCat("arg", i));
     }
   }
@@ -368,6 +400,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
 // type of the final output.
 Status BuildComputation(
     const std::vector<XlaCompiler::Argument>& args,
+    const std::vector<int>& arg_cores,
     const std::vector<XlaExpression>& retvals,
     const std::vector<std::unique_ptr<XlaResource>>& resources,
     bool return_updated_values_for_all_resources,
@@ -398,6 +431,8 @@ Status BuildComputation(
 
   for (const XlaResource* resource : arg_resources) {
     const XlaCompiler::Argument& arg = args[resource->arg_num];
+    DCHECK_LT(resource->arg_num, arg_cores.size());
+    const int core = arg_cores[resource->arg_num];
     bool modified =
         resource->value.handle() != resource->initial_value.handle();
     // TensorArray gradients were modified if their values changed or there are
@@ -417,8 +452,21 @@ Status BuildComputation(
       for (const auto& grad : resource->tensor_array_gradients) {
         update.tensor_array_gradients_accessed.insert(grad.first);
       }
+
+      // Request that the value be returned on a specific core.
+      xla::ScopedShardingAssignment assign_sharding(
+          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+                              : xla::ShardingBuilder::AssignDevice(core));
+
       xla::ComputationDataHandle handle;
       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+
+      // Since we can't change the sharding metadata at this point, create a
+      // tuple/get-tuple-element combination so that the sharding assignment
+      // will be placed on this value, which will cause the resource update
+      // to be returned from the same device that provided the resource.
+      handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
+
       elems.push_back(handle);
     }
   }
@@ -479,9 +527,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   result->tuple_arg = options.use_tuple_arg;
 
   std::vector<XlaExpression> arg_expressions;
+  std::vector<int> arg_cores;
   TF_RETURN_IF_ERROR(BuildArguments(
-      args, options.use_tuple_arg, &builder, context, &arg_expressions,
-      &result->input_mapping, &result->xla_input_shapes));
+      *graph, args, options.use_tuple_arg, &builder, context, &arg_cores,
+      &arg_expressions, &result->input_mapping, &result->xla_input_shapes));
   context->set_args(std::move(arg_expressions));
 
   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
@@ -491,7 +540,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   int num_computation_outputs;
   result->computation = std::make_shared<xla::Computation>();
   TF_RETURN_IF_ERROR(BuildComputation(
-      args, context->retvals(), context->resources(),
+      args, arg_cores, context->retvals(), context->resources(),
       options.return_updated_values_for_all_resources, &builder,
       result->computation.get(), &num_computation_outputs,
       &num_nonconst_outputs, &result->resource_updates));
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index a8882a638ca..4d40ca5825a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -255,8 +255,7 @@ class XlaCompiler {
 
   Status CompileFunction(const CompileOptions& options,
                          const NameAttrList& fn_name_attrs,
-                         const std::vector<Argument>& args,
-                         CompilationResult* result);
+                         std::vector<Argument> args, CompilationResult* result);
 
   // Compiles a tensorflow::Graph into an xla::Computation.
   // Similar to CompileFunction, but takes a Graph as input rather than a
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index d2821749479..bc7ad06a3fe 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -129,14 +129,18 @@ class ComputationBuilder {
     metadata_.Clear();
   }
 
-  // Sets an OpDeviceAssignment that will be attached to all instructions
-  // until cleared.
+  // Sets an OpSharding that will be attached to all instructions until cleared.
   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
 
-  // Clears the device assignment. Ops will be placed according to the default
-  // placement policy.
+  // Clears the sharding. Ops will be sharded according to the default placement
+  // policy.
   void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
 
+  // Returns the OpSharding that will be attached to all instructions.
+  const tensorflow::gtl::optional<OpSharding>& sharding() const {
+    return sharding_;
+  }
+
   // Sets the builder to a mode where it will die immediately when an error is
   // encountered, rather than producing it in a deferred fashion when Build() is
   // called (which is the default).
@@ -1038,6 +1042,33 @@ ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
   return ConstantFromArray(values);
 }
 
+// RAII-style object: sets the current sharding assignment in builder on
+// construction, and restores the previous assignment on destruction.
+class ScopedShardingAssignment {
+ public:
+  ScopedShardingAssignment(xla::ComputationBuilder* builder,
+                           tensorflow::gtl::optional<OpSharding> sharding)
+      : builder_(builder), prev_sharding_(builder->sharding()) {
+    SetSharding(sharding);
+  }
+
+  ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
+
+ private:
+  void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
+    if (sharding.has_value()) {
+      builder_->SetSharding(sharding.value());
+    } else {
+      builder_->ClearSharding();
+    }
+  }
+
+  xla::ComputationBuilder* const builder_;
+  tensorflow::gtl::optional<OpSharding> prev_sharding_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
+};
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 338a4304f32..d521297d994 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -319,8 +319,11 @@ def replicate(computation,
     # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
     # be rewritten away, leading to a runtime error.
     # TODO(phawkins): extend the rewrite to elide these nodes instead.
-    with ops.device(core(0)):
-      output_tensors = [array_ops.identity(x) for x in output_tensors]
+    new_output_tensors = []
+    for t in output_tensors:
+      with ops.device(t.device if t.device else core(0)):
+        new_output_tensors.append(array_ops.identity(t))
+    output_tensors = new_output_tensors
   finally:
     context.Exit()
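
Usage note (not part of the patch): a minimal sketch of how the new helpers compose. The device string, core count, and function name below are illustrative assumptions, not taken from the commit.

#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"

// Builds one instruction under a maximal sharding parsed from a hypothetical
// replicated-core device name. ScopedShardingAssignment restores the
// builder's previous sharding (or clears it) when the guard leaves scope.
xla::ComputationDataHandle BuildShardedConstant(
    xla::ComputationBuilder* builder) {
  auto sharding_or = tensorflow::ParseShardingFromDevice(
      "/device:TPU_REPLICATED_CORE:2", /*num_cores_per_replica=*/8);
  TF_CHECK_OK(sharding_or.status());
  // Instructions built in this scope carry OpSharding_Type_MAXIMAL, core 2.
  xla::ScopedShardingAssignment assign_sharding(builder,
                                                sharding_or.ValueOrDie());
  return builder->ConstantR0<float>(42.0f);
}

The tuple/get-tuple-element trick in retval_op.cc and BuildComputation follows the same idea: wrapping a value in Tuple and immediately extracting it yields a fresh instruction onto which the scoped sharding can be attached, since the sharding of the original value can no longer be changed at that point.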