Add compiler/tf2xla/sharding_util.h with utilities for getting the core device from a Node.

PiperOrigin-RevId: 174133602
This commit is contained in:
A. Unique TensorFlower 2017-10-31 20:49:18 -07:00 committed by TensorFlower Gardener
parent ab4349a26c
commit 27412f3b64
10 changed files with 321 additions and 31 deletions

View File

@ -123,6 +123,7 @@ cc_library(
":const_analysis",
":dump_graph",
":functionalize_control_flow",
":sharding_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@ -169,6 +170,36 @@ cc_library(
],
)
# Utilities for mapping TF device names / Nodes to XLA core sharding
# assignments (see sharding_util.h). Public so other compiler bridges can
# reuse the parsing logic.
cc_library(
    name = "sharding_util",
    srcs = ["sharding_util.cc"],
    hdrs = ["sharding_util.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
# Unit tests for the device-name -> sharding parsing in :sharding_util.
tf_cc_test(
    name = "sharding_util_test",
    srcs = ["sharding_util_test.cc"],
    deps = [
        ":sharding_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
# Internal targets below this point.
cc_library(

View File

@ -60,7 +60,13 @@ class RetvalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
tc.AddRetval(index_, dtype_, input);
// The core from which a return value is returned depends on the core
// assignment of the input to the retval. Since we can't change the core
// assignment of <input> at this point, create a tuple/get-tuple-element
// combination so that the core will be set on them.
auto tuple_elem =
ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
tc.AddRetval(index_, dtype_, tuple_elem);
}
}
}

View File

@ -0,0 +1,72 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE";
// Builds the InvalidArgument error reported when a parsed replicated-core id
// falls outside [0, num_cores).
static Status CoreOutOfRangeError(int core_id, int num_cores) {
  return errors::InvalidArgument("Invalid replicated core id: ", core_id,
                                 "; num_cores_per_replica=", num_cores);
}
// Parses a sharding assignment out of <device_name>. Empty names and
// non-replicated-core device types yield an empty optional; a parsed core id
// outside [0, num_cores_per_replica) is an error.
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) {
  // No device name means no sharding constraint.
  if (device_name.empty()) {
    return tensorflow::gtl::optional<xla::OpSharding>();
  }

  DeviceNameUtils::ParsedName parsed_device;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
    return errors::InvalidArgument("Malformed assigned device '", device_name,
                                   "'");
  }

  // Only device types ending in REPLICATED_CORE carry a core assignment;
  // anything else (e.g. /cpu:0) maps to "no sharding".
  const bool is_replicated_core =
      parsed_device.has_type &&
      StringPiece(parsed_device.type).ends_with(DEVICE_SUFFIX_REPLICATED_CORE);
  if (!is_replicated_core) {
    return tensorflow::gtl::optional<xla::OpSharding>();
  }

  const int core_id = parsed_device.id;
  if (core_id < 0 || core_id >= num_cores_per_replica) {
    return CoreOutOfRangeError(core_id, num_cores_per_replica);
  }
  return tensorflow::gtl::optional<xla::OpSharding>(
      xla::ShardingBuilder::AssignDevice(core_id));
}
// Node overload: derives the device name from <node>, preferring the
// placer-assigned device and falling back to the user-requested one.
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const Node& node, int num_cores_per_replica) {
  const string& device = node.assigned_device_name().empty()
                             ? node.requested_device()
                             : node.assigned_device_name();
  return ParseShardingFromDevice(device, num_cores_per_replica);
}
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
string device_name = src.assigned_device_name();
if (device_name.empty()) {
device_name = src.requested_device();
}
dst->set_assigned_device_name(device_name);
}
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_

#include <string>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Parses the op sharding from the 'replicated core' device name
// <device_name>.
// Returns an error:
// - if the device name is invalid.
// - if the core is parsed and is out of the range
//   [0, num_cores_per_replica).
//
// Otherwise, returns either a non-value or a sharding set as per
// xla::ShardingBuilder::AssignDevice.
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica);

// As above, but the device name is taken from <node>: its assigned device
// name if set, otherwise its requested device name.
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const Node& node, int num_cores_per_replica);

// Copies the device assignment from <src> to <dst> (assigned device name,
// falling back to the requested device name when unset).
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_

View File

@ -0,0 +1,58 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
// Exercises ParseShardingFromDevice over empty names, replicated-core names
// (in-range and out-of-range ids), and ordinary device names.
TEST(CoreUtilTest, ParseShardingFromDevice) {
  // Maps a parsed sharding to the single maximal-sharding device id, or -1
  // when no sharding (or a non-maximal sharding) was produced.
  auto core_from_sharding =
      [](tensorflow::gtl::optional<xla::OpSharding> sharding) -> int64 {
    if (sharding.has_value() &&
        sharding.value().type() ==
            xla::OpSharding::Type::OpSharding_Type_MAXIMAL) {
      return sharding.value().tile_assignment_devices(0);
    }
    return -1;
  };

  // An empty device name parses successfully to "no sharding".
  auto parse_status = ParseShardingFromDevice("", 1);
  TF_EXPECT_OK(parse_status.status());
  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
  parse_status = ParseShardingFromDevice("", 100);
  TF_EXPECT_OK(parse_status.status());
  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));

  // A negative core id is rejected.
  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:-1", 100);
  EXPECT_FALSE(parse_status.ok());

  // An in-range core id on a *REPLICATED_CORE device yields that core.
  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:55", 100);
  TF_EXPECT_OK(parse_status.status());
  EXPECT_EQ(55, core_from_sharding(parse_status.ValueOrDie()));

  // Core id == num_cores_per_replica is out of range.
  parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:100", 100);
  EXPECT_FALSE(parse_status.ok());

  // A non-replicated-core device yields no sharding.
  parse_status = ParseShardingFromDevice("/cpu:0", 100);
  TF_EXPECT_OK(parse_status.status());
  EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
}
} // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/common_runtime/local_device.h"
@ -97,23 +98,19 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
metadata.set_op_name(op_kernel->name());
b->SetOpMetadata(metadata);
DeviceNameUtils::ParsedName parsed;
OP_REQUIRES(
context,
DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
errors::Internal("Unable to parse device name: ",
op_kernel->requested_device()));
// If no device ID assignment is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on
// device 0.
if (parsed.has_id) {
b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
}
auto sharding_parse_result = ParseShardingFromDevice(
op_kernel->requested_device(), std::numeric_limits<int>::max());
OP_REQUIRES_OK(context, sharding_parse_result.status());
tensorflow::gtl::optional<xla::OpSharding> op_sharding =
sharding_parse_result.ValueOrDie();
// If no sharding metadata is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on device
// 0.
xla::ScopedShardingAssignment assign_sharding(b, op_sharding);
op_kernel->Compute(context);
b->ClearOpMetadata();
b->ClearSharding();
VLOG(4) << "Done";
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -160,10 +161,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
return graph;
}
Status XlaCompiler::CompileFunction(
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult* result) {
Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
const NameAttrList& function,
std::vector<XlaCompiler::Argument> args,
XlaCompiler::CompilationResult* result) {
const string function_id =
Canonicalize(function.name(), AttrSlice(&function.attr()));
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
@ -241,13 +242,15 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation.
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
Status BuildArguments(const Graph& graph,
const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, xla::ComputationBuilder* builder,
XlaContext* context,
XlaContext* context, std::vector<int>* arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping,
std::vector<xla::Shape>* input_shapes) {
arg_expressions->resize(args.size());
*arg_cores = std::vector<int>(args.size(), -1);
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
@ -302,6 +305,27 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
(*input_mapping)[i] = parameters[i];
}
// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (StringPiece(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
<< "_Arg out of bounds: " << index << " vs " << args.size();
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
if (sharding.has_value()) {
TF_RET_CHECK(sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
const int core = sharding.value().tile_assignment_devices(0);
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
(*arg_cores)[index] = core;
}
break;
}
}
// Build parameter handles for non-constant arguments.
std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
if (use_tuple_arg) {
@ -309,10 +333,18 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
xla::ComputationDataHandle tuple =
builder->Parameter(0, tuple_shape, "arg_tuple");
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const int core = (*arg_cores)[parameters[i]];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::ShardingBuilder::AssignDevice(core));
arg_handles[i] = builder->GetTupleElement(tuple, i);
}
} else {
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const int core = (*arg_cores)[parameters[i]];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::ShardingBuilder::AssignDevice(core));
arg_handles[i] =
builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
}
@ -368,6 +400,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// type of the final output.
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
const std::vector<XlaExpression>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
bool return_updated_values_for_all_resources,
@ -398,6 +431,8 @@ Status BuildComputation(
for (const XlaResource* resource : arg_resources) {
const XlaCompiler::Argument& arg = args[resource->arg_num];
const int core = arg_cores[resource->arg_num];
DCHECK_LT(resource->arg_num, arg_cores.size());
bool modified =
resource->value.handle() != resource->initial_value.handle();
// TensorArray gradients were modified if their values changed or there are
@ -417,8 +452,21 @@ Status BuildComputation(
for (const auto& grad : resource->tensor_array_gradients) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
// Request that the value be returned on a specific core.
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::ShardingBuilder::AssignDevice(core));
xla::ComputationDataHandle handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Since we can't change the sharding metadata of <value> as this point,
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
elems.push_back(handle);
}
}
@ -479,9 +527,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
result->tuple_arg = options.use_tuple_arg;
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
TF_RETURN_IF_ERROR(BuildArguments(
args, options.use_tuple_arg, &builder, context, &arg_expressions,
&result->input_mapping, &result->xla_input_shapes));
*graph, args, options.use_tuple_arg, &builder, context, &arg_cores,
&arg_expressions, &result->input_mapping, &result->xla_input_shapes));
context->set_args(std::move(arg_expressions));
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
@ -491,7 +540,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_computation_outputs;
result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
args, context->retvals(), context->resources(),
args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs,
&num_nonconst_outputs, &result->resource_updates));

View File

@ -255,8 +255,7 @@ class XlaCompiler {
Status CompileFunction(const CompileOptions& options,
const NameAttrList& fn_name_attrs,
const std::vector<Argument>& args,
CompilationResult* result);
std::vector<Argument> args, CompilationResult* result);
// Compiles a tensorflow::Graph into an xla::Computation.
// Similar to CompileFunction, but takes a Graph as input rather than a

View File

@ -129,14 +129,18 @@ class ComputationBuilder {
metadata_.Clear();
}
// Sets an OpDeviceAssignment that will be attached to all instructions
// until cleared.
// Sets an OpSharding that will be attached to all instructions until cleared.
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
// Clears the device assignment. Ops will be placed according to the default
// placement policy.
// Clears the sharding. Ops will be sharded according to the default placement
// policy.
void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
// Returns the OpSharding that will be attached to all instructions.
const tensorflow::gtl::optional<OpSharding>& sharding() const {
return sharding_;
}
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
// called (which is the default).
@ -1038,6 +1042,33 @@ ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
return ConstantFromArray(values);
}
// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
// RAII-style object: sets the current sharding assignment in builder on
// construction, and restores the previous assignment on destruction.
class ScopedShardingAssignment {
 public:
  // Takes the new sharding by const reference to avoid copying the
  // OpSharding proto; an empty optional clears the builder's sharding.
  ScopedShardingAssignment(
      xla::ComputationBuilder* builder,
      const tensorflow::gtl::optional<OpSharding>& sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  // Applies <sharding> to the builder: sets it when present, clears it
  // otherwise.
  void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::ComputationBuilder* const builder_;
  // Sharding that was in effect when this object was constructed.
  tensorflow::gtl::optional<OpSharding> prev_sharding_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_

View File

@ -319,8 +319,11 @@ def replicate(computation,
# because the TPUReplicatedInput/TPUReplicatedOutput operator would not
# be rewritten away, leading to a runtime error.
# TODO(phawkins): extend the rewrite to elide these nodes instead.
with ops.device(core(0)):
output_tensors = [array_ops.identity(x) for x in output_tensors]
new_output_tensors = []
for t in output_tensors:
with ops.device(t.device if t.device else core(0)):
new_output_tensors.append(array_ops.identity(t))
output_tensors = new_output_tensors
finally:
context.Exit()