Add compiler/tf2xla/sharding_util.h with utilities for getting the core device from
a Node. PiperOrigin-RevId: 174133602
This commit is contained in:
parent
ab4349a26c
commit
27412f3b64
@ -123,6 +123,7 @@ cc_library(
|
||||
":const_analysis",
|
||||
":dump_graph",
|
||||
":functionalize_control_flow",
|
||||
":sharding_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -169,6 +170,36 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sharding_util",
|
||||
srcs = ["sharding_util.cc"],
|
||||
hdrs = ["sharding_util.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "sharding_util_test",
|
||||
srcs = ["sharding_util_test.cc"],
|
||||
deps = [
|
||||
":sharding_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Internal targets below this point.
|
||||
|
||||
cc_library(
|
||||
|
@ -60,7 +60,13 @@ class RetvalOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
|
||||
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
|
||||
} else {
|
||||
tc.AddRetval(index_, dtype_, input);
|
||||
// The core from which a return value is returned depends on the core
|
||||
// assignment of the input to the retval .Since we can't change the core
|
||||
// assignment of <input> as this point, create a tuple/get-tuple-element
|
||||
// combination so that the core will be set on them.
|
||||
auto tuple_elem =
|
||||
ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
|
||||
tc.AddRetval(index_, dtype_, tuple_elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
72
tensorflow/compiler/tf2xla/sharding_util.cc
Normal file
72
tensorflow/compiler/tf2xla/sharding_util.cc
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE";
|
||||
|
||||
static Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid replicated core id: ", core,
|
||||
"; num_cores_per_replica=", num_cores_per_replica);
|
||||
}
|
||||
|
||||
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
|
||||
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) {
|
||||
if (device_name.empty()) {
|
||||
return tensorflow::gtl::optional<xla::OpSharding>();
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName parsed_device;
|
||||
if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
|
||||
return errors::InvalidArgument("Malformed assigned device '", device_name,
|
||||
"'");
|
||||
}
|
||||
if (!parsed_device.has_type ||
|
||||
!StringPiece(parsed_device.type)
|
||||
.ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) {
|
||||
return tensorflow::gtl::optional<xla::OpSharding>();
|
||||
} else {
|
||||
const int core = parsed_device.id;
|
||||
if (core < 0 || core >= num_cores_per_replica) {
|
||||
return CoreOutOfRangeError(core, num_cores_per_replica);
|
||||
}
|
||||
return tensorflow::gtl::optional<xla::OpSharding>(
|
||||
xla::ShardingBuilder::AssignDevice(core));
|
||||
}
|
||||
}
|
||||
|
||||
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
|
||||
ParseShardingFromDevice(const Node& node, int num_cores_per_replica) {
|
||||
string device_name = node.assigned_device_name();
|
||||
if (device_name.empty()) {
|
||||
device_name = node.requested_device();
|
||||
}
|
||||
return ParseShardingFromDevice(device_name, num_cores_per_replica);
|
||||
}
|
||||
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
|
||||
string device_name = src.assigned_device_name();
|
||||
if (device_name.empty()) {
|
||||
device_name = src.requested_device();
|
||||
}
|
||||
dst->set_assigned_device_name(device_name);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
44
tensorflow/compiler/tf2xla/sharding_util.h
Normal file
44
tensorflow/compiler/tf2xla/sharding_util.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Parses the op sharding from the 'replicated core' device_name <device_name>.
|
||||
// Returns an error:
|
||||
// - if the device name is invalid.
|
||||
// - the core is parsed and is out of the range [0, num_cores_per_replica).
|
||||
//
|
||||
// Otherwise, returns either a non-value or a sharding set as per
|
||||
// xla:ShardingBuilder::AssignDevice.
|
||||
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
|
||||
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica);
|
||||
|
||||
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
|
||||
ParseShardingFromDevice(const Node& node, int num_cores_per_replica);
|
||||
|
||||
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
|
58
tensorflow/compiler/tf2xla/sharding_util_test.cc
Normal file
58
tensorflow/compiler/tf2xla/sharding_util_test.cc
Normal file
@ -0,0 +1,58 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(CoreUtilTest, ParseShardingFromDevice) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
auto core_from_sharding =
|
||||
[](tensorflow::gtl::optional<xla::OpSharding> sharding) -> int64 {
|
||||
if (sharding.has_value() &&
|
||||
sharding.value().type() ==
|
||||
xla::OpSharding::Type::OpSharding_Type_MAXIMAL) {
|
||||
return sharding.value().tile_assignment_devices(0);
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
auto parse_status = ParseShardingFromDevice("", 1);
|
||||
TF_EXPECT_OK(parse_status.status());
|
||||
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
|
||||
parse_status = ParseShardingFromDevice("", 100);
|
||||
TF_EXPECT_OK(parse_status.status());
|
||||
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
|
||||
|
||||
parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:-1", 100);
|
||||
EXPECT_FALSE(parse_status.ok());
|
||||
|
||||
parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:55", 100);
|
||||
TF_EXPECT_OK(parse_status.status());
|
||||
EXPECT_EQ(55, core_from_sharding(parse_status.ValueOrDie()));
|
||||
|
||||
parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:100", 100);
|
||||
EXPECT_FALSE(parse_status.ok());
|
||||
|
||||
parse_status = ParseShardingFromDevice("/cpu:0", 100);
|
||||
TF_EXPECT_OK(parse_status.status());
|
||||
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
@ -97,23 +98,19 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
|
||||
metadata.set_op_name(op_kernel->name());
|
||||
b->SetOpMetadata(metadata);
|
||||
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed),
|
||||
errors::Internal("Unable to parse device name: ",
|
||||
op_kernel->requested_device()));
|
||||
// If no device ID assignment is found, XLA is free to use whatever device it
|
||||
// wants. In practice this usually has the effect of placing things on
|
||||
// device 0.
|
||||
if (parsed.has_id) {
|
||||
b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id));
|
||||
}
|
||||
auto sharding_parse_result = ParseShardingFromDevice(
|
||||
op_kernel->requested_device(), std::numeric_limits<int>::max());
|
||||
OP_REQUIRES_OK(context, sharding_parse_result.status());
|
||||
tensorflow::gtl::optional<xla::OpSharding> op_sharding =
|
||||
sharding_parse_result.ValueOrDie();
|
||||
|
||||
// If no sharding metadata is found, XLA is free to use whatever device it
|
||||
// wants. In practice this usually has the effect of placing things on device
|
||||
// 0.
|
||||
xla::ScopedShardingAssignment assign_sharding(b, op_sharding);
|
||||
op_kernel->Compute(context);
|
||||
|
||||
b->ClearOpMetadata();
|
||||
b->ClearSharding();
|
||||
VLOG(4) << "Done";
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
|
||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
@ -160,10 +161,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
||||
return graph;
|
||||
}
|
||||
|
||||
Status XlaCompiler::CompileFunction(
|
||||
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
|
||||
const NameAttrList& function,
|
||||
std::vector<XlaCompiler::Argument> args,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
const string function_id =
|
||||
Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
|
||||
@ -241,13 +242,15 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
|
||||
// Builds XLA computations for each of the arguments to the computation.
|
||||
// `args` are the arguments to the computation.
|
||||
Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
Status BuildArguments(const Graph& graph,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
bool use_tuple_arg, xla::ComputationBuilder* builder,
|
||||
XlaContext* context,
|
||||
XlaContext* context, std::vector<int>* arg_cores,
|
||||
std::vector<XlaExpression>* arg_expressions,
|
||||
std::vector<int>* input_mapping,
|
||||
std::vector<xla::Shape>* input_shapes) {
|
||||
arg_expressions->resize(args.size());
|
||||
*arg_cores = std::vector<int>(args.size(), -1);
|
||||
|
||||
// Argument numbers of arguments and resources that are to be passed to the
|
||||
// XLA computation as runtime parameters.
|
||||
@ -302,6 +305,27 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
(*input_mapping)[i] = parameters[i];
|
||||
}
|
||||
|
||||
// Use the _Arg nodes in the graph to resolve core assignments.
|
||||
for (const Node* n : graph.nodes()) {
|
||||
if (StringPiece(n->type_string()) != "_Arg") continue;
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RET_CHECK(index >= 0 && index < args.size())
|
||||
<< "_Arg out of bounds: " << index << " vs " << args.size();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharding,
|
||||
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
|
||||
if (sharding.has_value()) {
|
||||
TF_RET_CHECK(sharding.value().type() ==
|
||||
xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
|
||||
const int core = sharding.value().tile_assignment_devices(0);
|
||||
if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
|
||||
(*arg_cores)[index] = core;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Build parameter handles for non-constant arguments.
|
||||
std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
|
||||
if (use_tuple_arg) {
|
||||
@ -309,10 +333,18 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
xla::ComputationDataHandle tuple =
|
||||
builder->Parameter(0, tuple_shape, "arg_tuple");
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
const int core = (*arg_cores)[parameters[i]];
|
||||
xla::ScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
|
||||
: xla::ShardingBuilder::AssignDevice(core));
|
||||
arg_handles[i] = builder->GetTupleElement(tuple, i);
|
||||
}
|
||||
} else {
|
||||
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
|
||||
const int core = (*arg_cores)[parameters[i]];
|
||||
xla::ScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
|
||||
: xla::ShardingBuilder::AssignDevice(core));
|
||||
arg_handles[i] =
|
||||
builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
|
||||
}
|
||||
@ -368,6 +400,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
|
||||
// type of the final output.
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::vector<int>& arg_cores,
|
||||
const std::vector<XlaExpression>& retvals,
|
||||
const std::vector<std::unique_ptr<XlaResource>>& resources,
|
||||
bool return_updated_values_for_all_resources,
|
||||
@ -398,6 +431,8 @@ Status BuildComputation(
|
||||
|
||||
for (const XlaResource* resource : arg_resources) {
|
||||
const XlaCompiler::Argument& arg = args[resource->arg_num];
|
||||
const int core = arg_cores[resource->arg_num];
|
||||
DCHECK_LT(resource->arg_num, arg_cores.size());
|
||||
bool modified =
|
||||
resource->value.handle() != resource->initial_value.handle();
|
||||
// TensorArray gradients were modified if their values changed or there are
|
||||
@ -417,8 +452,21 @@ Status BuildComputation(
|
||||
for (const auto& grad : resource->tensor_array_gradients) {
|
||||
update.tensor_array_gradients_accessed.insert(grad.first);
|
||||
}
|
||||
|
||||
// Request that the value be returned on a specific core.
|
||||
xla::ScopedShardingAssignment assign_sharding(
|
||||
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
|
||||
: xla::ShardingBuilder::AssignDevice(core));
|
||||
|
||||
xla::ComputationDataHandle handle;
|
||||
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
|
||||
|
||||
// Since we can't change the sharding metadata of <value> as this point,
|
||||
// create a tuple/get-tuple-element combination so that sharding
|
||||
// assignment will be placed on this value, which will cause the resource
|
||||
// update to be returned from the same device that provided the resource.
|
||||
handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
|
||||
|
||||
elems.push_back(handle);
|
||||
}
|
||||
}
|
||||
@ -479,9 +527,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
result->tuple_arg = options.use_tuple_arg;
|
||||
|
||||
std::vector<XlaExpression> arg_expressions;
|
||||
std::vector<int> arg_cores;
|
||||
TF_RETURN_IF_ERROR(BuildArguments(
|
||||
args, options.use_tuple_arg, &builder, context, &arg_expressions,
|
||||
&result->input_mapping, &result->xla_input_shapes));
|
||||
*graph, args, options.use_tuple_arg, &builder, context, &arg_cores,
|
||||
&arg_expressions, &result->input_mapping, &result->xla_input_shapes));
|
||||
context->set_args(std::move(arg_expressions));
|
||||
|
||||
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
|
||||
@ -491,7 +540,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
int num_computation_outputs;
|
||||
result->computation = std::make_shared<xla::Computation>();
|
||||
TF_RETURN_IF_ERROR(BuildComputation(
|
||||
args, context->retvals(), context->resources(),
|
||||
args, arg_cores, context->retvals(), context->resources(),
|
||||
options.return_updated_values_for_all_resources, &builder,
|
||||
result->computation.get(), &num_computation_outputs,
|
||||
&num_nonconst_outputs, &result->resource_updates));
|
||||
|
@ -255,8 +255,7 @@ class XlaCompiler {
|
||||
|
||||
Status CompileFunction(const CompileOptions& options,
|
||||
const NameAttrList& fn_name_attrs,
|
||||
const std::vector<Argument>& args,
|
||||
CompilationResult* result);
|
||||
std::vector<Argument> args, CompilationResult* result);
|
||||
|
||||
// Compiles a tensorflow::Graph into an xla::Computation.
|
||||
// Similar to CompileFunction, but takes a Graph as input rather than a
|
||||
|
@ -129,14 +129,18 @@ class ComputationBuilder {
|
||||
metadata_.Clear();
|
||||
}
|
||||
|
||||
// Sets an OpDeviceAssignment that will be attached to all instructions
|
||||
// until cleared.
|
||||
// Sets an OpSharding that will be attached to all instructions until cleared.
|
||||
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
|
||||
|
||||
// Clears the device assignment. Ops will be placed according to the default
|
||||
// placement policy.
|
||||
// Clears the sharding. Ops will be sharded according to the default placement
|
||||
// policy.
|
||||
void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
|
||||
|
||||
// Returns the OpSharding that will be attached to all instructions.
|
||||
const tensorflow::gtl::optional<OpSharding>& sharding() const {
|
||||
return sharding_;
|
||||
}
|
||||
|
||||
// Sets the builder to a mode where it will die immediately when an error is
|
||||
// encountered, rather than producing it in a deferred fashion when Build() is
|
||||
// called (which is the default).
|
||||
@ -1038,6 +1042,33 @@ ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
|
||||
return ConstantFromArray(values);
|
||||
}
|
||||
|
||||
// RAII-style object: sets the current sharding assignment in builder on
|
||||
// construction, and sets back to the previous assignment on destruction.
|
||||
class ScopedShardingAssignment {
|
||||
public:
|
||||
ScopedShardingAssignment(xla::ComputationBuilder* builder,
|
||||
tensorflow::gtl::optional<OpSharding> sharding)
|
||||
: builder_(builder), prev_sharding_(builder->sharding()) {
|
||||
SetSharding(sharding);
|
||||
}
|
||||
|
||||
~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
|
||||
|
||||
private:
|
||||
void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
|
||||
if (sharding.has_value()) {
|
||||
builder_->SetSharding(sharding.value());
|
||||
} else {
|
||||
builder_->ClearSharding();
|
||||
}
|
||||
}
|
||||
|
||||
xla::ComputationBuilder* const builder_;
|
||||
tensorflow::gtl::optional<OpSharding> prev_sharding_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
|
||||
|
@ -319,8 +319,11 @@ def replicate(computation,
|
||||
# because the TPUReplicatedInput/TPUReplicatedOutput operator would not
|
||||
# be rewritten away, leading to a runtime error.
|
||||
# TODO(phawkins): extend the rewrite to elide these nodes instead.
|
||||
with ops.device(core(0)):
|
||||
output_tensors = [array_ops.identity(x) for x in output_tensors]
|
||||
new_output_tensors = []
|
||||
for t in output_tensors:
|
||||
with ops.device(t.device if t.device else core(0)):
|
||||
new_output_tensors.append(array_ops.identity(t))
|
||||
output_tensors = new_output_tensors
|
||||
finally:
|
||||
context.Exit()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user