diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 6abc9e268e3..81785b2d89b 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -95,6 +95,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -112,6 +113,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -135,6 +137,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -167,6 +170,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -183,6 +187,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -200,6 +205,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
       kernel_label_(kernel_label),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -217,6 +223,7 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(clear_colocations
                                   ? std::unordered_set<string>()
@@ -237,6 +244,25 @@ Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(assigned_device),
+      xla_cluster_(other.impl()->xla_cluster_),
+      colocation_constraints_(other.impl()->colocation_constraints_),
+      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
+Scope::Impl::Impl(const Scope& other, Tags::XlaCluster,
+                  const string& xla_cluster)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(xla_cluster),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -326,6 +352,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
   if (!impl()->assigned_device_.empty()) {
     builder->AssignedDevice(impl()->assigned_device_);
   }
+  if (!impl()->xla_cluster_.empty()) {
+    builder->XlaCluster(impl()->xla_cluster_);
+  }
 }

 string Scope::Impl::GetUniqueName(const string& prefix,
@@ -388,7 +417,7 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
                         false /* copy_names */));
 }

-Scope Scope::WithOpName(const string& op_name) const {
+Scope Scope::WithOpNameImpl(const string& op_name) const {
   if (impl()->single_use_scope()) {
     UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
                                          " on this scope"));
@@ -425,6 +454,10 @@ Scope Scope::WithAssignedDevice(const string& assigned_device) const {
   return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
 }

+Scope Scope::WithXlaCluster(const string& xla_cluster) const {
+  return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster));
+}
+
 Scope Scope::ColocateWith(const Operation& op) const {
   return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
                         /* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index e307d8989b6..0a75f23725c 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "absl/strings/str_cat.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -69,8 +70,9 @@ struct CompositeOpScopes;
 /// // W will be named "linear/W"
 /// auto W = Variable(linear.WithOpName("W"),
 ///                   {2, 2}, DT_FLOAT);
-/// // b will be named "linear/b"
-/// auto b = Variable(linear.WithOpName("b"),
+/// // b will be named "linear/b_3"
+/// int idx = 3;
+/// auto b = Variable(linear.WithOpName("b_", idx),
 ///                   {2}, DT_FLOAT);
 /// auto x = Const(linear, {...});  // name: "linear/Const"
 /// auto m = MatMul(linear, x, W);  // name: "linear/MatMul"
@@ -113,8 +115,11 @@ class Scope {
   Scope NewSubScope(const string& child_scope_name) const;

   /// Return a new scope. All ops created within the returned scope will have
-  /// names of the form `name/op_name[_suffix]`.
-  Scope WithOpName(const string& op_name) const;
+  /// names of the form `name/StrCat(fragments...)[_suffix]`.
+  template <typename... Ty>
+  Scope WithOpName(Ty... fragments) const {
+    return WithOpNameImpl(absl::StrCat(fragments...));
+  }

   /// Return a new scope. All ops created within the returned scope will have as
   /// control dependencies the union of operations in the control_deps vector
@@ -137,6 +142,10 @@ class Scope {
   /// their assigned device set to `assigned_device`.
   Scope WithAssignedDevice(const string& assigned_device) const;

+  /// Returns a new scope. All ops created within the returned scope will have
+  /// their _XlaCluster attribute set to `xla_cluster`.
+  Scope WithXlaCluster(const string& xla_cluster) const;
+
   /// Return a new scope. All ops created within the returned scope will be
   /// co-located on the device where op is placed.
   /// NOTE: This function is intended to be used by internal libraries only for
@@ -227,6 +236,8 @@ class Scope {
   // END_SKIP_DOXYGEN

 private:
+  Scope WithOpNameImpl(const string& op_name) const;
+
   friend class InternalScope;
   std::unique_ptr<Impl> impl_;
   explicit Scope(Impl*);
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 514e02e8414..5db7eab2b81 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -61,6 +61,7 @@ class Scope::Impl {
     enum class KernelLabel;
     enum class Colocate;
     enum class AssignedDevice;
+    enum class XlaCluster;
   };

   Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -78,6 +79,7 @@ class Scope::Impl {
   Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
        bool clear_colocations);
   Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
+  Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster);

   std::unordered_set<string> GetColocationConstraints(
       const Operation& colocate_with_op) const;
@@ -112,6 +114,7 @@ class Scope::Impl {
   const string kernel_label_ = "";
   const string device_ = "";
   const string assigned_device_ = "";
+  const string xla_cluster_ = "";
   const std::unordered_set<string> colocation_constraints_;

   // If true, Scope::DoShapeInference() always returns Status::OK().
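Note: a minimal usage sketch of the two Scope additions above (not part of the patch; the sub-scope and cluster names are illustrative):

    Scope root = Scope::NewRootScope();
    Scope linear = root.NewSubScope("linear");

    // WithOpName is now variadic and StrCat's its fragments, so numeric
    // suffixes no longer need an absl::StrCat at every call site.
    int idx = 3;
    auto b = ops::Variable(linear.WithOpName("b_", idx), {2}, DT_FLOAT);
    // b's node is named "linear/b_3".

    // Ops created under this scope carry the _XlaCluster attribute, which
    // GetXlaClusterForNode (used by the new pass below) reads back.
    Scope clustered = linear.WithXlaCluster("cluster_0");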
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index b920849a9f8..a6d8adddf77 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -448,6 +448,7 @@ cc_library(
         "encapsulate_subgraphs_pass.cc",
         "encapsulate_xla_computations_pass.cc",
         "extract_outside_compilation_pass.cc",
+        "increase_dynamism_for_auto_jit_pass.cc",
         "mark_for_compilation_pass.cc",
         "mark_for_compilation_pass_test_helper.cc",
         "partially_decluster_pass.cc",
@@ -458,6 +459,7 @@ cc_library(
         "encapsulate_subgraphs_pass.h",
         "encapsulate_xla_computations_pass.h",
         "extract_outside_compilation_pass.h",
+        "increase_dynamism_for_auto_jit_pass.h",
         "mark_for_compilation_pass.h",
         "mark_for_compilation_pass_test_helper.h",
         "partially_decluster_pass.h",
@@ -480,6 +482,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
+        "//tensorflow/compiler/tf2xla/cc:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:core_cpu",
@@ -493,6 +496,7 @@ cc_library(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
@@ -576,6 +580,7 @@ tf_cc_test(
         "encapsulate_subgraphs_pass_test.cc",
         "encapsulate_xla_computations_pass_test.cc",
         "extract_outside_compilation_pass_test.cc",
+        "increase_dynamism_for_auto_jit_pass_test.cc",
         "mark_for_compilation_pass_test.cc",
         "partially_decluster_pass_test.cc",
     ],
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
new file mode 100644
index 00000000000..f7ffa2589ae
--- /dev/null
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
@@ -0,0 +1,331 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) {
+  TF_RET_CHECK(n->type_string() == "Const");
+  const TensorProto* proto = nullptr;
+  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
+  Tensor tensor(proto->dtype());
+  TF_RET_CHECK(tensor.FromProto(*proto));
+  *out_tensor = std::move(tensor);
+  return Status::OK();
+}
+
+struct SliceInputs {
+  Output slice_op;
+  Output input;
+  Output begin;
+  Output size;
+
+  // The size of the TF slice operation as a std::vector.  We can always
+  // compute this because we only manipulate slices with a Const size.
+  std::vector<int64> size_as_vector;
+};
+
+std::vector<int64> IntTensorAsVector(const Tensor& t) {
+  DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64);
+  std::vector<int64> result;
+  result.reserve(t.NumElements());
+  for (int i = 0; i < t.NumElements(); i++) {
+    int64 element = t.dtype() == DT_INT32
+                        ? static_cast<int64>(t.flat<int32>()(i))
+                        : t.flat<int64>()(i);
+    result.push_back(element);
+  }
+  return result;
+}
+
+// Packages up the inputs to a Slice operation into an instance of
+// `SliceInputs`.
+Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
+  const int kSliceInputIndex = 0;
+  const int kSliceBeginIndex = 1;
+  const int kSliceSizeIndex = 2;
+
+  const Edge* slice_input_edge;
+  TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge));
+  const Edge* slice_size_edge;
+  TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
+  const Edge* slice_begin_edge;
+  TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));
+
+  slice_inputs->input =
+      Output(slice_input_edge->src(), slice_input_edge->src_output());
+  slice_inputs->begin =
+      Output(slice_begin_edge->src(), slice_begin_edge->src_output());
+  slice_inputs->size =
+      Output(slice_size_edge->src(), slice_size_edge->src_output());
+
+  Tensor tf_slice_size;
+  TF_RETURN_IF_ERROR(
+      GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size));
+
+  if (tf_slice_size.dims() != 1) {
+    return errors::Internal("Expected vector for the slice size input.");
+  }
+
+  slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size);
+  return Status::OK();
+}
+
+// Casts `x` to a DT_INT64 if it isn't one already.
+Output MakeInt64(const Scope& host_scope, absl::string_view name,
+                 const Output& x) {
+  return x.type() == DT_INT64
+             ? x
+             : ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64);
+}
+
+// Returns `slice_inputs` with the index and size inputs cast to DT_INT64.
+SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope,
+                                       const SliceInputs& slice_inputs) {
+  SliceInputs result;
+  result.input = slice_inputs.input;
+  result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin);
+  result.size = MakeInt64(host_scope, "size", slice_inputs.size);
+  result.size_as_vector = slice_inputs.size_as_vector;
+  return result;
+}
+
+// This class caches emitted constants to avoid creating multiple nodes for
+// the same constant value.  This helps make the generated GraphDef more
+// readable.
+class ConstantCache {
+ public:
+  explicit ConstantCache(const Scope& s) : scope_(s) {}
+
+  Output Get1DHostConstant(int64 constant) {
+    auto it = cache_.find(constant);
+    if (it == cache_.end()) {
+      Output new_const =
+          ops::Const(scope_.WithOpName("const_", constant), {constant});
+      it = cache_.insert({constant, new_const}).first;
+    }
+    return it->second;
+  }
+
+ private:
+  Scope scope_;
+  std::unordered_map<int64, Output> cache_;
+};
+
+// Returns a node computing the size of the Slice op with inputs
+// `slice_inputs`.
+Status ComputeSliceSize(const Scope& host_scope,
+                        const SliceInputs& slice_inputs, Output* size) {
+  // If slice_size[i] >= 0 then slice_size[i] = slice_size[i].
+  //
+  // If slice_size[i] == -1 then slice_size[i] = input_size[i] - begin[i].
+  //
+  // If slice_size[i] < -1 then executing the slice will throw an error, and
+  // we don't do anything here.  We've already filtered these cases out in
+  // IsRewritableSlice.
+  if (absl::c_all_of(slice_inputs.size_as_vector,
+                     [](int64 i) { return i >= 0; })) {
+    *size = slice_inputs.size;
+    return Status::OK();
+  }
+
+  Output input_shape =
+      ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input,
+                 ops::Shape::OutType(DT_INT64));
+
+  ConstantCache constant_pool(host_scope);
+
+  std::vector<Output> slice_size;
+  for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) {
+    if (slice_inputs.size_as_vector[i] >= 0) {
+      slice_size.push_back(
+          constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
+      continue;
+    }
+
+    DCHECK_EQ(slice_inputs.size_as_vector[i], -1);
+
+    Output begin_i = ops::Slice(host_scope.WithOpName("begin_", i),
+                                slice_inputs.begin,
+                                constant_pool.Get1DHostConstant(i),
+                                constant_pool.Get1DHostConstant(1));
+
+    Output input_shape_i = ops::Slice(host_scope.WithOpName("input_shape_", i),
+                                      input_shape,
+                                      constant_pool.Get1DHostConstant(i),
+                                      constant_pool.Get1DHostConstant(1));
+
+    slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i),
+                                  input_shape_i, begin_i));
+    DCHECK_EQ(slice_size.back().type(), DT_INT64);
+  }
+
+  *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
+                      ops::Const(host_scope.WithOpName("concat_axis"), 0));
+  return Status::OK();
+}
+
+// Terminology: a "static sized" slice is a slice with the
+// _XlaCompileTimeConstantInputs attribute set to {2}.  The output shape of
+// these slices can be solely determined by their "size" input.
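+// (Index 2 is the Slice's "size" operand -- see kSliceSizeIndex above -- so
+// the attribute marks the size, and only the size, as a value XLA must treat
+// as a compile-time constant.)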
+Status ConvertTensorFlowSliceToStaticShapedSlice(
+    Graph* g, Node* slice, const SliceInputs& slice_inputs,
+    absl::string_view cluster_name, Node** result) {
+  string host_name;
+  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
+      slice->assigned_device_name(), &host_name));
+
+  Status status;
+  Scope main_scope =
+      NewInternalScope(g, &status, /*refiner=*/nullptr)
+          .WithXlaCluster(string(cluster_name))
+          .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
+  Scope host_scope = main_scope.WithAssignedDevice(host_name);
+
+  SliceInputs slice_inputs_int64 =
+      MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);
+
+  Output slice_size;
+  TF_RETURN_IF_ERROR(
+      ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size));
+
+  *result =
+      ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name())
+                     .WithOpName("static_shaped_slice"),
+                 slice_inputs_int64.input, slice_inputs_int64.begin,
+                 slice_size)
+          .node();
+
+  std::vector<int> compile_time_const_inputs;
+  compile_time_const_inputs.push_back(2);
+  (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
+                     compile_time_const_inputs);
+  return status;
+}
+
+void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice,
+                                                 Node* static_shaped_slice) {
+  absl::InlinedVector<const Edge*, 4> edges_to_remove;
+  std::vector<const Edge*> slice_out_edges;
+  absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges));
+  for (const Edge* e : slice_out_edges) {
+    DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot);
+
+    int src_output = e->src_output();
+    int dst_input = e->dst_input();
+    Node* dst = e->dst();
+    g->RemoveEdge(e);
+    g->AddEdge(static_shaped_slice, src_output, dst, dst_input);
+  }
+
+  for (const Edge* e : slice->in_edges()) {
+    if (e->IsControlEdge()) {
+      g->AddControlEdge(e->src(), static_shaped_slice);
+    }
+  }
+
+  g->RemoveNode(slice);
+}
+
+Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
+                    absl::string_view cluster_name) {
+  VLOG(3) << "Rewriting slice " << slice->name()
+          << " to a \"static shaped\" Slice";
+  Node* static_shaped_slice;
+  TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice(
+      g, slice, slice_inputs, cluster_name, &static_shaped_slice));
+  ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice);
+  return Status::OK();
+}
+
+// Returns true if `n` is a slice we can rewrite to have a static shape
+// (i.e. have the output shape only depend on the "size" input).  Fills in
+// `slice_inputs` in the process.
+bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) {
+  if (n->type_string() != "Slice") {
+    return false;
+  }
+
+  if (!GetXlaClusterForNode(*n).has_value()) {
+    // There is no need to change slice ops outside XLA clusters.
+    return false;
+  }
+
+  if (!GetSliceInputs(n, slice_inputs).ok()) {
+    // Could not parse slice inputs.  E.g. the sizes input was not a constant.
+    return false;
+  }
+
+  // If slice_size[i] < -1 for any i then executing the slice will throw an
+  // error, and we don't do anything here.
+  return absl::c_all_of(slice_inputs->size_as_vector,
+                        [](int64 size_i) { return size_i >= -1; });
+}
+
+Status FindAndRewriteSlices(Graph* g, bool* changed) {
+  std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
+  for (Node* n : g->nodes()) {
+    SliceInputs slice_inputs;
+    if (IsRewritableSlice(n, &slice_inputs)) {
+      slices_to_rewrite.push_back({n, std::move(slice_inputs)});
+    }
+  }
+
+  for (const auto& pair : slices_to_rewrite) {
+    TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second,
+                                    *GetXlaClusterForNode(*pair.first)));
+  }
+
+  if (!slices_to_rewrite.empty()) {
+    // We've added constants to the graph; hook them up to _SOURCE.
+    FixupSourceAndSinkEdges(g);
+  }
+
+  *changed = !slices_to_rewrite.empty();
+
+  return Status::OK();
+}
+}  // namespace
+
+Status IncreaseDynamismForAutoJitPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  bool changed;
+  TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
+  if (changed) {
+    legacy_flags::MarkForCompilationPassFlags* flags =
+        legacy_flags::GetMarkForCompilationPassFlags();
+    if (flags->tf_xla_clustering_debug) {
+      dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
+                                  **options.graph, options.flib_def);
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h
new file mode 100644
index 00000000000..818ca948d64
--- /dev/null
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+// Increases the amount of "dynamism" representable by XLA clusters by
+// rewriting the TensorFlow graph.  This pass does the following rewrites:
+//
+// Slice
+// -----
+//
+//   Slice(op, begin, size) =>
+//     Slice(op, begin, actual_size(op.shape(), size, begin));
+//       _XlaCompileTimeConstantInputs={2}
+//
+// where
+//
+//   actual_size(op_shape, size, begin)[i] =
+//       size[i] == -1 ? (op_shape[i] - begin[i])
+//                     : size[i]
+//
+// This pass, combined with jit/partially_decluster_pass, reduces the number
+// of unnecessary cluster recompilations in some common cases.  After the
+// rewrite shown above, jit/partially_decluster_pass extracts the
+// actual_size(...) computation to outside the XLA cluster, causing the
+// cluster to be versioned only on the actual size of the XlaDynamicSlice.
+// This avoids recompilation due to superficial changes that don't affect
+// tensor shapes.
+//
+// Future work TODO(b/111210515)
+// -----------------------------
+//
+// In the future we will also translate StridedSlice and Pad in a similar way.
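+//
+// Worked example (illustrative, not part of the original comment): with a
+// rank-2 input of shape [d0, d1] and size = {-1, 500}, the rewritten Slice
+// receives actual_size = {d0 - begin[0], 500}, computed on the host at run
+// time, so the cluster is versioned on that size vector alone rather than on
+// `begin` or on the input shape.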
+class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass {
+ public:
+  Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
new file mode 100644
index 00000000000..06cd7cf2dd7
--- /dev/null
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
@@ -0,0 +1,285 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/compiler/jit/node_matchers.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+using testing::matchers::AssignedDevice;
+using testing::matchers::Attr;
+using testing::matchers::Const;
+using testing::matchers::CtrlDeps;
+using testing::matchers::Inputs;
+using testing::matchers::Name;
+using testing::matchers::NodeWith;
+using testing::matchers::Op;
+using testing::matchers::Out;
+
+// A fake device used to populate a DeviceSet.
+class FakeDevice : public Device {
+ public:
+  explicit FakeDevice(const DeviceAttributes& device_attributes)
+      : Device(nullptr, device_attributes) {}
+
+  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
+
+  Allocator* GetAllocator(AllocatorAttributes attr) override {
+    return nullptr;
+  }
+
+  static std::unique_ptr<Device> Make(const string& name, const string& type) {
+    DeviceAttributes device_attributes;
+    device_attributes.set_name(name);
+    device_attributes.set_device_type(DeviceType(type).type());
+    return absl::make_unique<FakeDevice>(device_attributes);
+  }
+};
+
+const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0";
+const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";
+
+Status IncreaseDynamismForAutoJit(const Scope& s,
+                                  std::unique_ptr<Graph>* result) {
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU));
+  devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU));
+
+  std::unique_ptr<DeviceSet> device_set(new DeviceSet());
+  for (auto& device : devices) {
+    device_set->AddDevice(device.get());
+  }
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  SessionOptions session_options;
+  session_options.config.mutable_graph_options()
+      ->mutable_optimizer_options()
+      ->set_global_jit_level(OptimizerOptions::ON_2);
+  GraphOptimizationPassOptions options;
+  options.graph = &graph;
+  options.device_set = device_set.get();
+  options.session_options = &session_options;
+
+  // Scope::ToGraph seems to drop assigned devices, probably because it goes
+  // through a GraphDef.  So explicitly maintain the device assignment.
+  std::unordered_map<string, string> assigned_device_names;
+  for (Node* n : s.graph()->nodes()) {
+    assigned_device_names[n->name()] = n->assigned_device_name();
+  }
+  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
+  for (Node* n : graph->nodes()) {
+    n->set_assigned_device_name(assigned_device_names[n->name()]);
+  }
+
+  IncreaseDynamismForAutoJitPass rewriter;
+  TF_RETURN_IF_ERROR(rewriter.Run(options));
+  *result = std::move(graph);
+  return Status::OK();
+}
+
+TEST(SliceToDynamicSliceRewriteTest, Basic) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  const int64 zero_64 = 0;
+  const int32 zero_32 = 0;
+  const int64 one_64 = 1;
+
+  auto m_input = Out(NodeWith(Op("Placeholder"), Name("input")));
+  auto m_begin_s64 = Out(NodeWith(
+      Op("Cast"), Inputs(Out(NodeWith(Op("Placeholder"), Name("begin"))))));
+  auto m_input_shape = Out(NodeWith(Op("Shape"), Inputs(m_input)));
+  auto m_slice_size_0 = Out(NodeWith(
+      Op("Sub"), AssignedDevice(kHostName),
+      Inputs(
+          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
+                       Inputs(m_input_shape, Const(zero_64), Const(one_64)))),
+          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
+                       Inputs(m_begin_s64, Const(zero_64), Const(one_64)))))));
+  auto m_dynamic_slice_size = Out(NodeWith(
+      Op("ConcatV2"), AssignedDevice(kHostName),
+      Inputs(m_slice_size_0, Const(static_cast<int64>(500)), Const(zero_32))));
+
+  std::vector<int> compile_time_constant_inputs;
+  compile_time_constant_inputs.push_back(2);
+  auto m_dynamic_slice = NodeWith(
+      Op("Slice"),
+      AssignedDevice(kDeviceName),
+      Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),
+      Inputs(m_input, m_begin_s64, m_dynamic_slice_size));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
+}
+
+TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
+  Output control_pred = ops::Placeholder(root.WithOpName("control"), DT_BOOL);
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+  root.graph()->AddControlEdge(control_pred.node(), slice.node());
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(static_shaped_slice,
+              NodeWith(Op("Slice"),
+                       CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  auto to_int64 = [](int v) { return static_cast<int64>(v); };
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
+  Output size =
+      ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("Cast")))));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+
+  // The shape refiner throws an error if we use a bogus constant value for
+  // size.  So we first use a Placeholder to placate the shape refiner, and
+  // later replace it with a bogus constant.
+  Output size_placeholder =
+      ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT32);
+  Output slice =
+      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
+
+  Output size = ops::Const(root.WithOpName("size"), {-8, 500});
+  TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(),
+                                        /*new_src_index=*/0,
+                                        /*dst=*/slice.node(),
+                                        /*dst_index=*/2));
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  EXPECT_THAT(result->nodes(),
+              Not(Contains(NodeWith(Op("Slice"),
+                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) {
+  Scope root =
+      Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName);
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  EXPECT_THAT(result->nodes(),
+              Not(Contains(NodeWith(Op("Slice"),
+                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
+  Output size = ops::Placeholder(root.WithOpName("size"), DT_INT64);
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  EXPECT_THAT(result->nodes(),
+              Not(Contains(NodeWith(Op("Slice"),
+                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  auto to_int64 = [](int v) { return static_cast<int64>(v); };
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
+
+  // The C++ node bindings immediately error out when we try to construct a
+  // bogus slice, so we first use a placeholder to construct the Slice and
+  // then replace the input.
+  Output size_placeholder = ops::Placeholder(root.WithOpName("size"), DT_INT64);
+  Output slice =
+      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
+
+  Output size =
+      ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}});
+  TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  EXPECT_THAT(result->nodes(),
+              Not(Contains(NodeWith(Op("Slice"),
+                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 085c0e5adbb..f79bdc1e2e8 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -44,17 +45,20 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + IncreaseDynamismForAutoJitPass); + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50, BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index a09a6eb1553..95b8ace306a 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -204,11 +204,12 @@ struct NodeMatcher : public ::testing::MatcherInterface { } return false; } - if (!AreAttrValuesEqual(it->second, attr_kv_pair.second)) { + if (attr_kv_pair.second && + !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) { if (listener->IsInterested()) { *listener << "attribute named " << attr_kv_pair.first << " does not match value; expected: \"" - << SummarizeAttrValue(attr_kv_pair.second) + << SummarizeAttrValue(*attr_kv_pair.second) << "\", found: \"" << SummarizeAttrValue(it->second) << "\""; } @@ -278,12 +279,14 @@ struct NodeMatcher : public ::testing::MatcherInterface { if (!attrs.empty()) { printed_something = true; std::vector attrs_str; - absl::c_transform(attrs, std::back_inserter(attrs_str), - [](const std::pair& attr_kv_pair) { - return absl::StrCat( - attr_kv_pair.first, "->", - SummarizeAttrValue(attr_kv_pair.second)); - }); + absl::c_transform( + attrs, std::back_inserter(attrs_str), + [](const std::pair>& attr_kv_pair) { + return absl::StrCat(attr_kv_pair.first, "->", + attr_kv_pair.second + ? SummarizeAttrValue(*attr_kv_pair.second) + : "*"); + }); *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ") << "]"; } @@ -327,7 +330,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { absl::optional>> input_matchers; absl::optional<::testing::Matcher>> control_dep_set; - std::map attrs; + std::map> attrs; }; // Matches a dst and dst_output on an input edge. 
@@ -472,12 +475,28 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
   return {bool_attr.first, attr_value};
 }

+std::pair<string, AttrValue> impl::AttrLiteralHelper(
+    const std::pair<string, absl::Span<const int>>& int_list_attr) {
+  AttrValue attr_value;
+  AttrValue::ListValue* list = attr_value.mutable_list();
+  for (int i : int_list_attr.second) {
+    list->add_i(i);
+  }
+  return {int_list_attr.first, attr_value};
+}
+
 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
   impl::NodeMatcherProperties props;
   props.set_attr(std::move(attr));
   return props;
 }

+impl::NodeMatcherProperties impl::Attr(string name) {
+  impl::NodeMatcherProperties props;
+  props.set_attr({std::move(name), absl::nullopt});
+  return props;
+}
+
 NodeMatcherProperties ConstantValue(
     const ::tensorflow::Input::Initializer& val) {
   TF_CHECK_OK(val.status);
@@ -486,9 +505,9 @@ NodeMatcherProperties ConstantValue(
   return props;
 }

-::testing::Matcher<const Node*> Const(
+::testing::Matcher<impl::OutEdge> Const(
     const ::tensorflow::Input::Initializer& val) {
-  return NodeWith(ConstantValue(val));
+  return Out(NodeWith(ConstantValue(val)));
 }

 ::testing::Matcher<impl::OutEdge> Out(
     int oidx, ::testing::Matcher<const Node*> node_matcher) {
diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h
index 35c2f5fd7b5..cd2ab53e404 100644
--- a/tensorflow/compiler/jit/node_matchers.h
+++ b/tensorflow/compiler/jit/node_matchers.h
@@ -84,7 +84,7 @@ class NodeMatcherProperties {
  public:
   using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
   using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>;
-  using AttrKeyValuePair = std::pair<string, AttrValue>;
+  using AttrKeyValuePair = std::pair<string, absl::optional<AttrValue>>;

   const absl::optional<string>& name() const { return name_; }
   const absl::optional<string>& op() const { return op_; }
@@ -163,9 +163,13 @@ impl::NodeMatcherProperties CtrlDeps(
     absl::Span<const ::testing::Matcher<const Node*>> control_deps);

 impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs);
+impl::NodeMatcherProperties Attr(string name);

 std::pair<string, AttrValue> AttrLiteralHelper(
     const std::pair<string, bool>& bool_attr);
+
+std::pair<string, AttrValue> AttrLiteralHelper(
+    const std::pair<string, absl::Span<const int>>& int_list_attr);
 }  // namespace impl

 // -----------------------------------------------------------------------------
@@ -187,6 +191,10 @@ impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
   return impl::Attr({impl::AttrLiteralHelper({name, value})});
 }

+inline impl::NodeMatcherProperties Attr(const string& name) {
+  return impl::Attr(name);
+}
+
 // Matches a node with inputs `inputs`.
 //
 // `inputs` are ordered; `inputs`[i] must match input i.
@@ -200,7 +208,8 @@ impl::NodeMatcherProperties Inputs(Ts... inputs) {
     ::testing::Matcher<const Node*> node);

 // Matches the first output of a node that matches `node`.
-::testing::Matcher<impl::OutEdge> Out(::testing::Matcher<const Node*> node) {
+inline ::testing::Matcher<impl::OutEdge> Out(
+    ::testing::Matcher<const Node*> node) {
   return Out(0, node);
 }

@@ -224,7 +233,7 @@ template <typename... Ts>
   return impl::NodeWith(array);
 }

-::testing::Matcher<const Node*> Const(
+::testing::Matcher<impl::OutEdge> Const(
     const ::tensorflow::Input::Initializer& val);
 }  // namespace matchers
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 05c6943f2c7..adaee479359 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -104,6 +104,11 @@ NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
   return *this;
 }

+NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
+  def_builder_.Attr("_XlaCluster", xla_cluster);
+  return *this;
+}
+
 Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
   // In case of error, set *created_node to nullptr.
   if (created_node != nullptr) *created_node = nullptr;
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index d576985a232..31fb5909393 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -103,6 +103,9 @@ class NodeBuilder {
   // Sets the device name in the "assigned device" field in tensorflow::Node.
   NodeBuilder& AssignedDevice(StringPiece device);

+  // Sets the _XlaCluster attribute in the created node to `xla_cluster`.
+  NodeBuilder& XlaCluster(StringPiece xla_cluster);
+
   // Set the value of an attr.  attr_name must match the name of one of
   // attrs defined by the Op, and value must have the corresponding type
   // (see SetAttrValue() in ../framework/attr_value_util.h for legal
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 8c24076aa9c..cb088faec1e 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -480,4 +480,16 @@ std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
   }
 }

+/*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
+    const string& device_name, string* host_device_name) {
+  DeviceNameUtils::ParsedName device;
+  if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
+    return errors::Internal("Could not parse device name ", device_name);
+  }
+  device.type = "CPU";
+  device.id = 0;
+  *host_device_name = DeviceNameUtils::ParsedNameToString(device);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 3f0bc605623..bb5e2b3f0c4 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -169,6 +169,11 @@ class DeviceNameUtils {
   // mapping.
   static std::vector<string> GetLocalNamesForDeviceMappings(
       const ParsedName& pn);
+
+  // Returns the name of the CPU:0 device on the same host as the device
+  // `device_name`.
+  static Status DeviceNameToCpuDeviceName(const string& device_name,
+                                          string* host_device_name);
 };

 }  // namespace tensorflow
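Note: a minimal sketch of the new DeviceNameUtils helper in isolation (the device name is illustrative):

    string host_device;
    Status s = DeviceNameUtils::DeviceNameToCpuDeviceName(
        "/job:worker/replica:0/task:0/device:GPU:1", &host_device);
    // On success, host_device == "/job:worker/replica:0/task:0/device:CPU:0":
    // the parsed name keeps its job/replica/task fields, and only the device
    // type and id are overwritten.  The pass above uses this to place the
    // slice-size computation on the CPU of the same host as the clustered
    // Slice.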