Introduce a pass to increase the amount of dynamism supported by an XLA cluster

Increases the amount of dynamism representable by XLA clusters by rewriting the
TensorFlow graph.  See the header for a description.

This pass, combined with jit/partially_decluster_pass, reduces the number of
unnecessary cluster recompilations in some common cases.

The CL is organized as follows:

 - cc/framework/scope* and core/graph/node_builder are modified so that new
   nodes can now be automatically put in an XLA cluster using
   Scope::WithXlaCluster (see the sketch after this list).

 - The pass is implemented in jit/increase_dynamism_for_auto_jit_pass.

 - In jit/jit_compilation_pass_registration, the new pass is registered to run
   between MarkForCompilationPass and PartiallyDeclusterPass.
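
Usage sketch (not part of the diff below; the cluster name, function name, op
name, and values are illustrative only):

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/const_op.h"

    namespace tensorflow {

    // Builds a Const node tagged with the hypothetical cluster "cluster_0".
    // Every node created under `clustered` gets its _XlaCluster attribute set,
    // and the new variadic WithOpName concatenates its arguments into the op
    // name (e.g. "size_0" for i == 0).
    Output BuildClusteredSizeConst(const Scope& root, int i) {
      Scope clustered = root.WithXlaCluster("cluster_0");
      return ops::Const(clustered.WithOpName("size_", i), {-1, 500});
    }

    }  // namespace tensorflow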

PiperOrigin-RevId: 218907734
Sanjoy Das 2018-10-26 13:48:16 -07:00 committed by TensorFlower Gardener
parent 33d6328d16
commit 2c164ed32f
14 changed files with 803 additions and 21 deletions

View File

@ -95,6 +95,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -112,6 +113,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -135,6 +137,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -167,6 +170,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -183,6 +187,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -200,6 +205,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
kernel_label_(kernel_label),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -217,6 +223,7 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
@ -237,6 +244,25 @@ Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(assigned_device),
xla_cluster_(other.impl()->xla_cluster_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::XlaCluster,
const string& xla_cluster)
: graph_(other.impl()->graph_),
status_(other.impl()->status_),
name_map_(other.impl()->name_map_),
refiner_(other.impl()->refiner_),
scope_used_(other.impl()->scope_used_),
control_deps_(other.impl()->control_deps_),
name_(other.impl()->name_),
op_name_(other.impl()->op_name_),
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
assigned_device_(other.impl()->assigned_device_),
xla_cluster_(xla_cluster),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@ -326,6 +352,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->assigned_device_.empty()) {
builder->AssignedDevice(impl()->assigned_device_);
}
if (!impl()->xla_cluster_.empty()) {
builder->XlaCluster(impl()->xla_cluster_);
}
}
string Scope::Impl::GetUniqueName(const string& prefix,
@ -388,7 +417,7 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
false /* copy_names */));
}
Scope Scope::WithOpName(const string& op_name) const {
Scope Scope::WithOpNameImpl(const string& op_name) const {
if (impl()->single_use_scope()) {
UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
" on this scope"));
@ -425,6 +454,10 @@ Scope Scope::WithAssignedDevice(const string& assigned_device) const {
return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
}
Scope Scope::WithXlaCluster(const string& xla_cluster) const {
return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster));
}
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -69,8 +70,9 @@ struct CompositeOpScopes;
/// // W will be named "linear/W"
/// auto W = Variable(linear.WithOpName("W"),
/// {2, 2}, DT_FLOAT);
/// // b will be named "linear/b"
/// auto b = Variable(linear.WithOpName("b"),
/// // b will be named "linear/b_3"
/// int idx = 3;
/// auto b = Variable(linear.WithOpName("b_", idx),
/// {2}, DT_FLOAT);
/// auto x = Const(linear, {...}); // name: "linear/Const"
/// auto m = MatMul(linear, x, W); // name: "linear/MatMul"
@ -113,8 +115,11 @@ class Scope {
Scope NewSubScope(const string& child_scope_name) const;
/// Return a new scope. All ops created within the returned scope will have
/// names of the form `name/op_name[_suffix]`.
Scope WithOpName(const string& op_name) const;
/// names of the form `name/StrCat(fragments...)[_suffix]`
template <typename... Ty>
Scope WithOpName(Ty... fragments) const {
return WithOpNameImpl(absl::StrCat(fragments...));
}
/// Return a new scope. All ops created within the returned scope will have as
/// control dependencies the union of operations in the control_deps vector
@ -137,6 +142,10 @@ class Scope {
/// their assigned device set to `assigned_device`.
Scope WithAssignedDevice(const string& assigned_device) const;
/// Returns a new scope. All ops created within the returned scope will have
/// their _XlaCluster attribute set to `xla_cluster`.
Scope WithXlaCluster(const string& xla_cluster) const;
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be used by internal libraries only for
@ -227,6 +236,8 @@ class Scope {
// END_SKIP_DOXYGEN
private:
Scope WithOpNameImpl(const string& op_name) const;
friend class InternalScope;
std::unique_ptr<Impl> impl_;
explicit Scope(Impl*);

View File

@ -61,6 +61,7 @@ class Scope::Impl {
enum class KernelLabel;
enum class Colocate;
enum class AssignedDevice;
enum class XlaCluster;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@ -78,6 +79,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
@ -112,6 +114,7 @@ class Scope::Impl {
const string kernel_label_ = "";
const string device_ = "";
const string assigned_device_ = "";
const string xla_cluster_ = "";
const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status::OK().

View File

@ -448,6 +448,7 @@ cc_library(
"encapsulate_subgraphs_pass.cc",
"encapsulate_xla_computations_pass.cc",
"extract_outside_compilation_pass.cc",
"increase_dynamism_for_auto_jit_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@ -458,6 +459,7 @@ cc_library(
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
"extract_outside_compilation_pass.h",
"increase_dynamism_for_auto_jit_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@ -480,6 +482,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
@ -493,6 +496,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@ -576,6 +580,7 @@ tf_cc_test(
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"extract_outside_compilation_pass_test.cc",
"increase_dynamism_for_auto_jit_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],

View File

@ -0,0 +1,331 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {
Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) {
TF_RET_CHECK(n->type_string() == "Const");
const TensorProto* proto = nullptr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
Tensor tensor(proto->dtype());
TF_RET_CHECK(tensor.FromProto(*proto));
*out_tensor = std::move(tensor);
return Status::OK();
}
struct SliceInputs {
Output slice_op;
Output input;
Output begin;
Output size;
// The size of the TF slice operation as a std::vector. We can always compute
// this because we only manipulate slices with a Const size.
std::vector<int64> size_as_vector;
};
std::vector<int64> IntTensorAsVector(const Tensor& t) {
DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64);
std::vector<int64> result;
result.reserve(t.NumElements());
for (int i = 0; i < t.NumElements(); i++) {
int64 element = t.dtype() == DT_INT32
? static_cast<int64>(t.flat<int32>()(i))
: t.flat<int64>()(i);
result.push_back(element);
}
return result;
}
// Packages up the inputs to a Slice operation into an instance of
// `SliceInputs`.
Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
const int kSliceInputIndex = 0;
const int kSliceBeginIndex = 1;
const int kSliceSizeIndex = 2;
const Edge* slice_input_edge;
TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge));
const Edge* slice_size_edge;
TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
const Edge* slice_begin_edge;
TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));
slice_inputs->input =
Output(slice_input_edge->src(), slice_input_edge->src_output());
slice_inputs->begin =
Output(slice_begin_edge->src(), slice_begin_edge->src_output());
slice_inputs->size =
Output(slice_size_edge->src(), slice_size_edge->src_output());
Tensor tf_slice_size;
TF_RETURN_IF_ERROR(
GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size));
if (tf_slice_size.dims() != 1) {
return errors::Internal("Expected vector for the slice size input.");
}
slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size);
return Status::OK();
}
// Casts `x` to a DT_INT64 if it isn't one already.
Output MakeInt64(const Scope& host_scope, absl::string_view name,
const Output& x) {
return x.type() == DT_INT64
? x
: ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64);
}
// Returns `slice_inputs` with the index and size inputs cast to DT_INT64.
SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope,
const SliceInputs& slice_inputs) {
SliceInputs result;
result.input = slice_inputs.input;
result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin);
result.size = MakeInt64(host_scope, "size", slice_inputs.size);
result.size_as_vector = slice_inputs.size_as_vector;
return result;
}
// This class caches emitted constants to avoid creating multiple nodes for the
// same constant value. This helps make the generated GraphDef more readable.
class ConstantCache {
public:
explicit ConstantCache(const Scope& s) : scope_(s) {}
Output Get1DHostConstant(int64 constant) {
auto it = cache_.find(constant);
if (it == cache_.end()) {
Output new_const =
ops::Const(scope_.WithOpName("const_", constant), {constant});
it = cache_.insert({constant, new_const}).first;
}
return it->second;
}
private:
Scope scope_;
std::unordered_map<int64, Output> cache_;
};
// Returns a node computing the size of the Slice op with inputs `slice_inputs`.
Status ComputeSliceSize(const Scope& host_scope,
const SliceInputs& slice_inputs, Output* size) {
// If slice_size[i] >= 0 then slice_size[i] = slice_size[i].
//
// If slice_size[i] == -1 then slice_size[i] = input_size[i] -
// begin[i].
//
// If slice_size[i] < -1 then executing the slice will throw an error, and we
// don't do anything here. We've already filtered these cases out in
// IsRewritableSlice.
if (absl::c_all_of(slice_inputs.size_as_vector,
[](int64 i) { return i >= 0; })) {
*size = slice_inputs.size;
return Status::OK();
}
Output input_shape =
ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input,
ops::Shape::OutType(DT_INT64));
ConstantCache constant_pool(host_scope);
std::vector<Output> slice_size;
for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) {
if (slice_inputs.size_as_vector[i] >= 0) {
slice_size.push_back(
constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
continue;
}
DCHECK_EQ(slice_inputs.size_as_vector[i], -1);
Output begin_i = ops::Slice(
host_scope.WithOpName("begin_", i), slice_inputs.begin,
constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));
Output input_shape_i = ops::Slice(
host_scope.WithOpName("input_shape_", i), input_shape,
constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));
slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i),
input_shape_i, begin_i));
DCHECK_EQ(slice_size.back().type(), DT_INT64);
}
*size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
ops::Const(host_scope.WithOpName("concat_axis"), 0));
return Status::OK();
}
// Terminology: a "static sized" slice is a slice with the
// _XlaCompileTimeConstantInputs attribute set to {2}. The output shape of
// these slices can be solely determined by their "size" input.
Status ConvertTensorFlowSliceToStaticShapedSlice(
Graph* g, Node* slice, const SliceInputs& slice_inputs,
absl::string_view cluster_name, Node** result) {
string host_name;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
slice->assigned_device_name(), &host_name));
Status status;
Scope main_scope =
NewInternalScope(g, &status, /*refiner=*/nullptr)
.WithXlaCluster(string(cluster_name))
.NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
Scope host_scope = main_scope.WithAssignedDevice(host_name);
SliceInputs slice_inputs_int64 =
MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);
Output slice_size;
TF_RETURN_IF_ERROR(
ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size));
*result =
ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name())
.WithOpName("static_shaped_slice"),
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
.node();
std::vector<int> compile_time_const_inputs;
compile_time_const_inputs.push_back(2);
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
compile_time_const_inputs);
return status;
}
void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice,
Node* static_shaped_slice) {
absl::InlinedVector<const Edge*, 6> edges_to_remove;
std::vector<const Edge*> slice_out_edges;
absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges));
for (const Edge* e : slice_out_edges) {
DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot);
int src_output = e->src_output();
int dst_input = e->dst_input();
Node* dst = e->dst();
g->RemoveEdge(e);
g->AddEdge(static_shaped_slice, src_output, dst, dst_input);
}
for (const Edge* e : slice->in_edges()) {
if (e->IsControlEdge()) {
g->AddControlEdge(e->src(), static_shaped_slice);
}
}
g->RemoveNode(slice);
}
Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
absl::string_view cluster_name) {
VLOG(3) << "Rewriting slice " << slice->name()
<< " to a \"static shaped\" Slice";
Node* static_shaped_slice;
TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice(
g, slice, slice_inputs, cluster_name, &static_shaped_slice));
ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice);
return Status::OK();
}
// Returns true if `n` is a slice we can rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input). Fills in
// `slice_inputs` in the process.
bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) {
if (n->type_string() != "Slice") {
return false;
}
if (!GetXlaClusterForNode(*n).has_value()) {
// There is no need to change slice ops outside XLA clusters.
return false;
}
if (!GetSliceInputs(n, slice_inputs).ok()) {
// Could not parse slice inputs. E.g. the sizes input was not a constant.
return false;
}
// If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here.
return absl::c_all_of(slice_inputs->size_as_vector,
[](int64 size_i) { return size_i >= -1; });
}
Status FindAndRewriteSlices(Graph* g, bool* changed) {
std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
for (Node* n : g->nodes()) {
SliceInputs slice_inputs;
if (IsRewritableSlice(n, &slice_inputs)) {
slices_to_rewrite.push_back({n, std::move(slice_inputs)});
}
}
for (const auto& pair : slices_to_rewrite) {
TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second,
*GetXlaClusterForNode(*pair.first)));
}
if (!slices_to_rewrite.empty()) {
// We've added constants to the graph; hook them up to _SOURCE.
FixupSourceAndSinkEdges(g);
}
*changed = !slices_to_rewrite.empty();
return Status::OK();
}
} // namespace
Status IncreaseDynamismForAutoJitPass::Run(
const GraphOptimizationPassOptions& options) {
bool changed;
TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
if (changed) {
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,57 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
#define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// Increases the amount of "dynamism" representable by XLA clusters by rewriting
// the TensorFlow graph. This pass does the following rewrites:
//
// Slice
// -----
//
// Slice(op, begin, size <must be constant>) =>
// Slice(op, begin, actual_size(op.shape(), size, begin));
// _XlaCompileTimeConstantInputs={2}
//
// where
//
// actual_size(op_shape, size, begin)[i] =
// size[i] == -1 ? (op_shape[i] - begin[i])
// : size[i]
//
// This pass, combined with jit/partially_decluster_pass, reduces the number of
// unnecessary cluster recompilations in some common cases. After the rewrite
// shown above, jit/partially_decluster_pass extracts the actual_size(...)
// computation to outside the XLA cluster, causing the cluster to be versioned
// only on the actual size of the XlaDynamicSlice. This avoids recompilation
// due to superficial changes that don't affect tensor shapes.
//
// Future Work TODO(b/111210515)
// -----------------------------
//
// In the future we will also translate StridedSlice and Pad in a similar way.
class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
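
To make the formula above concrete (the shapes here are hypothetical, not taken
from the tests): for a Slice whose input has shape [10, 512], with begin = [2, 0]
and size = [-1, 500], the emitted host computation evaluates
actual_size = [10 - 2, 500] = [8, 500]. The rewritten Slice consumes that value
as its size input (input index 2), and because that input is marked
compile-time constant, the slice's output shape is fully determined by the
computed size.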

View File

@ -0,0 +1,285 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
using testing::matchers::AssignedDevice;
using testing::matchers::Attr;
using testing::matchers::Const;
using testing::matchers::CtrlDeps;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;
// A fake device used to populate a DeviceSet.
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& device_attributes)
: Device(nullptr, device_attributes) {}
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
static std::unique_ptr<Device> Make(const string& name, const string& type) {
DeviceAttributes device_attributes;
device_attributes.set_name(name);
device_attributes.set_device_type(DeviceType(type).type());
return absl::make_unique<FakeDevice>(device_attributes);
}
};
const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";
Status IncreaseDynamismForAutoJit(const Scope& s,
std::unique_ptr<Graph>* result) {
std::vector<std::unique_ptr<Device>> devices;
devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU));
devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU));
std::unique_ptr<DeviceSet> device_set(new DeviceSet());
for (auto& device : devices) {
device_set->AddDevice(device.get());
}
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
GraphOptimizationPassOptions options;
options.graph = &graph;
options.device_set = device_set.get();
options.session_options = &session_options;
// Scope::ToGraph seems to drop assigned devices, probably because it goes
// through a GraphDef. So explicitly maintain the device assignment.
std::unordered_map<string, string> assigned_device_names;
for (Node* n : s.graph()->nodes()) {
assigned_device_names[n->name()] = n->assigned_device_name();
}
TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
for (Node* n : graph->nodes()) {
n->set_assigned_device_name(assigned_device_names[n->name()]);
}
IncreaseDynamismForAutoJitPass rewriter;
TF_RETURN_IF_ERROR(rewriter.Run(options));
*result = std::move(graph);
return Status::OK();
}
TEST(SliceToDynamicSliceRewriteTest, Basic) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
Output size = ops::Const(root.WithOpName("size"), {-1, 500});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
const int64 zero_64 = 0;
const int32 zero_32 = 0;
const int64 one_64 = 1;
auto m_input = Out(NodeWith(Op("Placeholder"), Name("input")));
auto m_begin_s64 = Out(NodeWith(
Op("Cast"), Inputs(Out(NodeWith(Op("Placeholder"), Name("begin"))))));
auto m_input_shape = Out(NodeWith(Op("Shape"), Inputs(m_input)));
auto m_slice_size_0 = Out(NodeWith(
Op("Sub"), AssignedDevice(kHostName),
Inputs(
Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
Inputs(m_input_shape, Const(zero_64), Const(one_64)))),
Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
Inputs(m_begin_s64, Const(zero_64), Const(one_64)))))));
auto m_dynamic_slice_size = Out(NodeWith(
Op("ConcatV2"), AssignedDevice(kHostName),
Inputs(m_slice_size_0, Const(static_cast<int64>(500)), Const(zero_32))));
std::vector<int> compile_time_constant_inputs;
compile_time_constant_inputs.push_back(2);
auto m_dynamic_slice = NodeWith(
Op("Slice"), AssignedDevice(kDeviceName),
Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),
Inputs(m_input, m_begin_s64, m_dynamic_slice_size));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(), "slice/static_shaped_slice/static_shaped_slice");
ASSERT_NE(static_shaped_slice, nullptr);
EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
}
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
Output size = ops::Const(root.WithOpName("size"), {-1, 500});
Output control_pred = ops::Placeholder(root.WithOpName("control"), DT_BOOL);
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
root.graph()->AddControlEdge(control_pred.node(), slice.node());
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(), "slice/static_shaped_slice/static_shaped_slice");
ASSERT_NE(static_shaped_slice, nullptr);
EXPECT_THAT(static_shaped_slice,
NodeWith(Op("Slice"),
CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
}
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
auto to_int64 = [](int v) { return static_cast<int64>(v); };
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
Output size =
ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("Cast")))));
}
TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
// The shape refiner throws an error if we use a bogus constant value for
// size. So we first use a Placeholder to placate the shape refiner, and
// later replace it with a bogus constant.
Output size_placeholder =
ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT32);
Output slice =
ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
Output size = ops::Const(root.WithOpName("size"), {-8, 500});
TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(),
/*new_src_index=*/0,
/*dst=*/slice.node(), /*dst_index=*/2));
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
EXPECT_THAT(result->nodes(),
Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) {
Scope root =
Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName);
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
Output size = ops::Const(root.WithOpName("size"), {-1, 500});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
EXPECT_THAT(result->nodes(),
Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
Output size = ops::Placeholder(root.WithOpName("size"), DT_INT64);
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
EXPECT_THAT(result->nodes(),
Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
auto to_int64 = [](int v) { return static_cast<int64>(v); };
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
// The C++ node bindings immediately error out when we try to construct a bogus
// slice, so we first use a placeholder to construct the Slice and then replace
// the input.
Output size_placeholder = ops::Placeholder(root.WithOpName("size"), DT_INT64);
Output slice =
ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
Output size =
ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}});
TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
EXPECT_THAT(result->nodes(),
Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr)))));
}
} // namespace
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@ -44,17 +45,20 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
IncreaseDynamismForAutoJitPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
PartiallyDeclusterPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
BuildXlaOpsPass);
} // namespace tensorflow

View File

@ -204,11 +204,12 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
}
return false;
}
if (!AreAttrValuesEqual(it->second, attr_kv_pair.second)) {
if (attr_kv_pair.second &&
!AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
if (listener->IsInterested()) {
*listener << "attribute named " << attr_kv_pair.first
<< " does not match value; expected: \""
<< SummarizeAttrValue(attr_kv_pair.second)
<< SummarizeAttrValue(*attr_kv_pair.second)
<< "\", found: \"" << SummarizeAttrValue(it->second)
<< "\"";
}
@ -278,12 +279,14 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
if (!attrs.empty()) {
printed_something = true;
std::vector<string> attrs_str;
absl::c_transform(attrs, std::back_inserter(attrs_str),
[](const std::pair<string, AttrValue>& attr_kv_pair) {
return absl::StrCat(
attr_kv_pair.first, "->",
SummarizeAttrValue(attr_kv_pair.second));
});
absl::c_transform(
attrs, std::back_inserter(attrs_str),
[](const std::pair<string, absl::optional<AttrValue>>& attr_kv_pair) {
return absl::StrCat(attr_kv_pair.first, "->",
attr_kv_pair.second
? SummarizeAttrValue(*attr_kv_pair.second)
: "*");
});
*os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
<< "]";
}
@ -327,7 +330,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
absl::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
control_dep_set;
std::map<string, AttrValue> attrs;
std::map<string, absl::optional<AttrValue>> attrs;
};
// Matches a dst and dst_output on an input edge. Today we only use this with
@ -472,12 +475,28 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
return {bool_attr.first, attr_value};
}
std::pair<string, AttrValue> impl::AttrLiteralHelper(
const std::pair<string, absl::Span<const int>>& int_list_attr) {
AttrValue attr_value;
AttrValue::ListValue* list = attr_value.mutable_list();
for (int i : int_list_attr.second) {
list->add_i(i);
}
return {int_list_attr.first, attr_value};
}
impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
impl::NodeMatcherProperties props;
props.set_attr(std::move(attr));
return props;
}
impl::NodeMatcherProperties impl::Attr(string name) {
impl::NodeMatcherProperties props;
props.set_attr({std::move(name), absl::nullopt});
return props;
}
NodeMatcherProperties ConstantValue(
const ::tensorflow::Input::Initializer& val) {
TF_CHECK_OK(val.status);
@ -486,9 +505,9 @@ NodeMatcherProperties ConstantValue(
return props;
}
::testing::Matcher<const Node*> Const(
::testing::Matcher<impl::OutEdge> Const(
const ::tensorflow::Input::Initializer& val) {
return NodeWith(ConstantValue(val));
return Out(NodeWith(ConstantValue(val)));
}
::testing::Matcher<impl::OutEdge> Out(
int oidx, ::testing::Matcher<const Node*> node_matcher) {

View File

@ -84,7 +84,7 @@ class NodeMatcherProperties {
public:
using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>;
using AttrKeyValuePair = std::pair<string, AttrValue>;
using AttrKeyValuePair = std::pair<string, absl::optional<AttrValue>>;
const absl::optional<string>& name() const { return name_; }
const absl::optional<string>& op() const { return op_; }
@ -163,9 +163,13 @@ impl::NodeMatcherProperties CtrlDeps(
absl::Span<const ::testing::Matcher<const Node*>> control_deps);
impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs);
impl::NodeMatcherProperties Attr(string name);
std::pair<string, AttrValue> AttrLiteralHelper(
const std::pair<string, bool>& bool_attr);
std::pair<string, AttrValue> AttrLiteralHelper(
const std::pair<string, absl::Span<const int>>& int_list_attr);
} // namespace impl
// -----------------------------------------------------------------------------
@ -187,6 +191,10 @@ impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
return impl::Attr({impl::AttrLiteralHelper({name, value})});
}
inline impl::NodeMatcherProperties Attr(const string& name) {
return impl::Attr(name);
}
// Matches a node with inputs `inputs`.
//
// `inputs` are ordered; `inputs`[i] must match input i.
@ -200,7 +208,8 @@ impl::NodeMatcherProperties Inputs(Ts... inputs) {
::testing::Matcher<const Node*> node);
// Matches the first output of a node that matches `node`.
::testing::Matcher<impl::OutEdge> Out(::testing::Matcher<const Node*> node) {
inline ::testing::Matcher<impl::OutEdge> Out(
::testing::Matcher<const Node*> node) {
return Out(0, node);
}
@ -224,7 +233,7 @@ template <typename... Ts>
return impl::NodeWith(array);
}
::testing::Matcher<const Node*> Const(
::testing::Matcher<impl::OutEdge> Const(
const ::tensorflow::Input::Initializer& val);
} // namespace matchers

View File

@ -104,6 +104,11 @@ NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
return *this;
}
NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
def_builder_.Attr("_XlaCluster", xla_cluster);
return *this;
}
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
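
Editorial sketch (not part of the diff): graphs assembled without the cc Scope
API can set the same attribute directly through the new NodeBuilder hook. The
node name, function name, and cluster below are hypothetical.

    #include "tensorflow/core/graph/node_builder.h"

    namespace tensorflow {

    // Creates a NoOp node and tags it with the hypothetical cluster "cluster_0"
    // by setting its _XlaCluster attribute at build time.
    Status AddClusteredNoOp(Graph* graph, Node** result) {
      return NodeBuilder("tagged_noop", "NoOp")
          .XlaCluster("cluster_0")
          .Finalize(graph, result);
    }

    }  // namespace tensorflow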

View File

@ -103,6 +103,9 @@ class NodeBuilder {
// Sets the device name in the "assigned device" field in tensorflow::Node.
NodeBuilder& AssignedDevice(StringPiece device);
// Sets the _XlaCluster attribute in the created node to `xla_cluster`.
NodeBuilder& XlaCluster(StringPiece xla_cluster);
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal

View File

@ -480,4 +480,16 @@ std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
}
}
/*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
const string& device_name, string* host_device_name) {
DeviceNameUtils::ParsedName device;
if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
return errors::Internal("Could not parse device name ", device_name);
}
device.type = "CPU";
device.id = 0;
*host_device_name = DeviceNameUtils::ParsedNameToString(device);
return Status::OK();
}
} // namespace tensorflow

View File

@ -169,6 +169,11 @@ class DeviceNameUtils {
// mapping.
static std::vector<string> GetLocalNamesForDeviceMappings(
const ParsedName& pn);
// Returns the name of the CPU:0 device on the same host as the device
// `device_name`.
static Status DeviceNameToCpuDeviceName(const string& device_name,
string* host_device_name);
};
} // namespace tensorflow
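
Editorial sketch (not part of the diff): how the new helper is meant to be
used. The GPU device name matches the one used in the test above; the wrapper
function is hypothetical.

    #include "tensorflow/core/util/device_name_utils.h"

    namespace tensorflow {

    // On success, *host_device becomes
    // "/job:worker/replica:0/task:0/device:CPU:0", i.e. the CPU:0 device on the
    // same host as the given GPU device.
    Status ExampleHostLookup(string* host_device) {
      return DeviceNameUtils::DeviceNameToCpuDeviceName(
          "/job:worker/replica:0/task:0/device:GPU:0", host_device);
    }

    }  // namespace tensorflow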