Introduce a pass to increase the amount of dynamism supported by an XLA cluster

Increases the amount of dynamism representable by XLA clusters by rewriting the
TensorFlow graph. See the header for a description. This pass, combined with
jit/partially_decluster_pass, reduces the number of unnecessary cluster
recompilations in some common cases.

The CL is organized as follows:

 - cc/framework/scope* and core/graph/node_builder are modified so that new
   nodes can now be automatically put in an XLA cluster using
   Scope::WithXlaCluster.
 - The pass is implemented in jit/increase_dynamism_for_auto_jit_pass.
 - In jit/jit_compilation_pass_registration, the new pass is registered to run
   between MarkForCompilationPass and PartiallyDeclusterPass.

PiperOrigin-RevId: 218907734

This commit is contained in:
parent 33d6328d16, commit 2c164ed32f
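Before the diff, a quick sketch of how the new API composes (illustrative only, not part of the commit; the op types and the cluster name "cluster_0" are arbitrary examples). Ops created under a scope returned by Scope::WithXlaCluster automatically pick up the _XlaCluster attribute:

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    using namespace tensorflow;  // for brevity in this sketch

    void BuildClusteredOps() {
      Scope root = Scope::NewRootScope().ExitOnError();
      // Every op created under `clustered` gets _XlaCluster="cluster_0".
      Scope clustered = root.WithXlaCluster("cluster_0");
      Output a = ops::Placeholder(clustered.WithOpName("a"), DT_FLOAT);
      Output b = ops::Neg(clustered.WithOpName("b"), a);  // also in cluster_0
    }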
tensorflow/cc/framework/scope.cc

@@ -95,6 +95,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -112,6 +113,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -135,6 +137,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -167,6 +170,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -183,6 +187,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -200,6 +205,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
       kernel_label_(kernel_label),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -217,6 +223,7 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(
           clear_colocations
               ? std::unordered_set<string>()

@@ -237,6 +244,25 @@ Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
       assigned_device_(assigned_device),
+      xla_cluster_(other.impl()->xla_cluster_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}

+Scope::Impl::Impl(const Scope& other, Tags::XlaCluster,
+                  const string& xla_cluster)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
+      xla_cluster_(xla_cluster),
+      colocation_constraints_(other.impl()->colocation_constraints_),
+      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

@@ -326,6 +352,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
   if (!impl()->assigned_device_.empty()) {
     builder->AssignedDevice(impl()->assigned_device_);
   }
+  if (!impl()->xla_cluster_.empty()) {
+    builder->XlaCluster(impl()->xla_cluster_);
+  }
 }

 string Scope::Impl::GetUniqueName(const string& prefix,

@@ -388,7 +417,7 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
                         false /* copy_names */));
 }

-Scope Scope::WithOpName(const string& op_name) const {
+Scope Scope::WithOpNameImpl(const string& op_name) const {
   if (impl()->single_use_scope()) {
     UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
                                          " on this scope"));

@@ -425,6 +454,10 @@ Scope Scope::WithAssignedDevice(const string& assigned_device) const {
   return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
 }

+Scope Scope::WithXlaCluster(const string& xla_cluster) const {
+  return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster));
+}

 Scope Scope::ColocateWith(const Operation& op) const {
   return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
                         /* clear_colocations */ false));
tensorflow/cc/framework/scope.h

@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "absl/strings/str_cat.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"

@@ -69,8 +70,9 @@ struct CompositeOpScopes;
 /// // W will be named "linear/W"
 /// auto W = Variable(linear.WithOpName("W"),
 ///                   {2, 2}, DT_FLOAT);
-/// // b will be named "linear/b"
-/// auto b = Variable(linear.WithOpName("b"),
+/// // b will be named "linear/b_3"
+/// int idx = 3;
+/// auto b = Variable(linear.WithOpName("b_", idx),
 ///                   {2}, DT_FLOAT);
 /// auto x = Const(linear, {...});  // name: "linear/Const"
 /// auto m = MatMul(linear, x, W);  // name: "linear/MatMul"

@@ -113,8 +115,11 @@ class Scope {
   Scope NewSubScope(const string& child_scope_name) const;

   /// Return a new scope. All ops created within the returned scope will have
-  /// names of the form `name/op_name[_suffix]`.
-  Scope WithOpName(const string& op_name) const;
+  /// names of the form `name/StrCat(fragments...)[_suffix]`
+  template <typename... Ty>
+  Scope WithOpName(Ty... fragments) const {
+    return WithOpNameImpl(absl::StrCat(fragments...));
+  }

   /// Return a new scope. All ops created within the returned scope will have as
   /// control dependencies the union of operations in the control_deps vector

@@ -137,6 +142,10 @@ class Scope {
   /// their assigned device set to `assigned_device`.
   Scope WithAssignedDevice(const string& assigned_device) const;

+  /// Returns a new scope. All ops created within the returned scope will have
+  /// their _XlaCluster attribute set to `xla_cluster`.
+  Scope WithXlaCluster(const string& xla_cluster) const;

   /// Return a new scope. All ops created within the returned scope will be
   /// co-located on the device where op is placed.
   /// NOTE: This function is intended to be used by internal libraries only for

@@ -227,6 +236,8 @@ class Scope {
   // END_SKIP_DOXYGEN

  private:
+  Scope WithOpNameImpl(const string& op_name) const;

   friend class InternalScope;
   std::unique_ptr<Impl> impl_;
   explicit Scope(Impl*);
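A brief sketch of the new variadic WithOpName in use (illustrative only; the loop and ops::Const are arbitrary examples). The fragments are StrCat'ed into the op name:

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    using namespace tensorflow;  // for brevity in this sketch

    void BuildNumberedConstants() {
      Scope linear = Scope::NewRootScope().NewSubScope("linear");
      for (int idx = 0; idx < 3; ++idx) {
        // Creates nodes named "linear/b_0", "linear/b_1", "linear/b_2".
        Output b = ops::Const(linear.WithOpName("b_", idx), 0.0f);
      }
    }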
tensorflow/cc/framework/scope_internal.h

@@ -61,6 +61,7 @@ class Scope::Impl {
     enum class KernelLabel;
     enum class Colocate;
     enum class AssignedDevice;
+    enum class XlaCluster;
   };

   Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,

@@ -78,6 +79,7 @@ class Scope::Impl {
   Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
        bool clear_colocations);
   Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
+  Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster);

   std::unordered_set<string> GetColocationConstraints(
       const Operation& colocate_with_op) const;

@@ -112,6 +114,7 @@ class Scope::Impl {
   const string kernel_label_ = "";
   const string device_ = "";
   const string assigned_device_ = "";
+  const string xla_cluster_ = "";
   const std::unordered_set<string> colocation_constraints_;

   // If true, Scope::DoShapeInference() always returns Status::OK().
tensorflow/compiler/jit/BUILD

@@ -448,6 +448,7 @@ cc_library(
         "encapsulate_subgraphs_pass.cc",
         "encapsulate_xla_computations_pass.cc",
         "extract_outside_compilation_pass.cc",
+        "increase_dynamism_for_auto_jit_pass.cc",
         "mark_for_compilation_pass.cc",
         "mark_for_compilation_pass_test_helper.cc",
         "partially_decluster_pass.cc",

@@ -458,6 +459,7 @@ cc_library(
         "encapsulate_subgraphs_pass.h",
         "encapsulate_xla_computations_pass.h",
         "extract_outside_compilation_pass.h",
+        "increase_dynamism_for_auto_jit_pass.h",
         "mark_for_compilation_pass.h",
         "mark_for_compilation_pass_test_helper.h",
         "partially_decluster_pass.h",

@@ -480,6 +482,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
+        "//tensorflow/compiler/tf2xla/cc:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:core_cpu",

@@ -493,6 +496,7 @@ cc_library(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],

@@ -576,6 +580,7 @@ tf_cc_test(
         "encapsulate_subgraphs_pass_test.cc",
         "encapsulate_xla_computations_pass_test.cc",
         "extract_outside_compilation_pass_test.cc",
+        "increase_dynamism_for_auto_jit_pass_test.cc",
         "mark_for_compilation_pass_test.cc",
         "partially_decluster_pass_test.cc",
     ],
tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc  (new file, 331 lines)

@@ -0,0 +1,331 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {
Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) {
  TF_RET_CHECK(n->type_string() == "Const");
  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
  Tensor tensor(proto->dtype());
  TF_RET_CHECK(tensor.FromProto(*proto));
  *out_tensor = std::move(tensor);
  return Status::OK();
}

struct SliceInputs {
  Output slice_op;
  Output input;
  Output begin;
  Output size;

  // The size of the TF slice operation as a std::vector. We can always compute
  // this because we only manipulate slices with a Const size.
  std::vector<int64> size_as_vector;
};

std::vector<int64> IntTensorAsVector(const Tensor& t) {
  DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64);
  std::vector<int64> result;
  result.reserve(t.NumElements());
  for (int i = 0; i < t.NumElements(); i++) {
    int64 element = t.dtype() == DT_INT32
                        ? static_cast<int64>(t.flat<int32>()(i))
                        : t.flat<int64>()(i);
    result.push_back(element);
  }
  return result;
}

// Packages up the inputs to a Slice operation into an instance of
// `SliceInputs`.
Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) {
  const int kSliceInputIndex = 0;
  const int kSliceBeginIndex = 1;
  const int kSliceSizeIndex = 2;

  const Edge* slice_input_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge));
  const Edge* slice_size_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge));
  const Edge* slice_begin_edge;
  TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge));
  slice_inputs->input =
      Output(slice_input_edge->src(), slice_input_edge->src_output());
  slice_inputs->begin =
      Output(slice_begin_edge->src(), slice_begin_edge->src_output());
  slice_inputs->size =
      Output(slice_size_edge->src(), slice_size_edge->src_output());

  Tensor tf_slice_size;
  TF_RETURN_IF_ERROR(
      GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size));

  if (tf_slice_size.dims() != 1) {
    return errors::Internal("Expected vector for the slice size input.");
  }

  slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size);
  return Status::OK();
}

// Casts `x` to a DT_INT64 if it isn't one already.
Output MakeInt64(const Scope& host_scope, absl::string_view name,
                 const Output& x) {
  return x.type() == DT_INT64
             ? x
             : ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64);
}

// Returns `slice_inputs` with the index and size inputs cast to DT_INT64.
SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope,
                                       const SliceInputs& slice_inputs) {
  SliceInputs result;
  result.input = slice_inputs.input;
  result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin);
  result.size = MakeInt64(host_scope, "size", slice_inputs.size);
  result.size_as_vector = slice_inputs.size_as_vector;
  return result;
}

// This class caches emitted constants to avoid creating multiple nodes for the
// same constant value. This helps make the generated GraphDef more readable.
class ConstantCache {
 public:
  explicit ConstantCache(const Scope& s) : scope_(s) {}

  Output Get1DHostConstant(int64 constant) {
    auto it = cache_.find(constant);
    if (it == cache_.end()) {
      Output new_const =
          ops::Const(scope_.WithOpName("const_", constant), {constant});
      it = cache_.insert({constant, new_const}).first;
    }
    return it->second;
  }

 private:
  Scope scope_;
  std::unordered_map<int, Output> cache_;
};

// Returns a node computing the size of the Slice op with inputs `slice_inputs`.
Status ComputeSliceSize(const Scope& host_scope,
                        const SliceInputs& slice_inputs, Output* size) {
  // If slice_size[i] >= 0 then slice_size[i] = slice_size[i].
  //
  // If slice_size[i] == -1 then slice_size[i] = input_size[i] -
  // begin[i].
  //
  // If slice_size[i] < -1 then executing the slice will throw an error, and we
  // don't do anything here. We've already filtered these cases out in
  // IsRewritableSlice.

  if (absl::c_all_of(slice_inputs.size_as_vector,
                     [](int64 i) { return i >= 0; })) {
    *size = slice_inputs.size;
    return Status::OK();
  }

  Output input_shape =
      ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input,
                 ops::Shape::OutType(DT_INT64));

  ConstantCache constant_pool(host_scope);

  std::vector<Output> slice_size;
  for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) {
    if (slice_inputs.size_as_vector[i] >= 0) {
      slice_size.push_back(
          constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i]));
      continue;
    }

    DCHECK_EQ(slice_inputs.size_as_vector[i], -1);

    Output begin_i = ops::Slice(
        host_scope.WithOpName("begin_", i), slice_inputs.begin,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    Output input_shape_i = ops::Slice(
        host_scope.WithOpName("input_shape_", i), input_shape,
        constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1));

    slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i),
                                  input_shape_i, begin_i));
    DCHECK_EQ(slice_size.back().type(), DT_INT64);
  }

  *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
                      ops::Const(host_scope.WithOpName("concat_axis"), 0));
  return Status::OK();
}

// Terminology: "static sized" slice is a slice with the
// _XlaCompileTimeConstantInputs attribute set to {2}. The output shape of
// these slices can be solely determined by their "size" input.
Status ConvertTensorFlowSliceToStaticShapedSlice(
    Graph* g, Node* slice, const SliceInputs& slice_inputs,
    absl::string_view cluster_name, Node** result) {
  string host_name;
  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
      slice->assigned_device_name(), &host_name));

  Status status;
  Scope main_scope =
      NewInternalScope(g, &status, /*refiner=*/nullptr)
          .WithXlaCluster(string(cluster_name))
          .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
  Scope host_scope = main_scope.WithAssignedDevice(host_name);

  SliceInputs slice_inputs_int64 =
      MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);

  Output slice_size;
  TF_RETURN_IF_ERROR(
      ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size));

  *result =
      ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name())
                     .WithOpName("static_shaped_slice"),
                 slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
          .node();
  std::vector<int> compile_time_const_inputs;
  compile_time_const_inputs.push_back(2);
  (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
                     compile_time_const_inputs);
  return status;
}

void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice,
                                                 Node* static_shaped_slice) {
  absl::InlinedVector<const Edge*, 6> edges_to_remove;
  std::vector<const Edge*> slice_out_edges;
  absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges));
  for (const Edge* e : slice_out_edges) {
    DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot);

    int src_output = e->src_output();
    int dst_input = e->dst_input();
    Node* dst = e->dst();
    g->RemoveEdge(e);
    g->AddEdge(static_shaped_slice, src_output, dst, dst_input);
  }

  for (const Edge* e : slice->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), static_shaped_slice);
    }
  }

  g->RemoveNode(slice);
}

Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
                    absl::string_view cluster_name) {
  VLOG(3) << "Rewriting slice " << slice->name()
          << " to a \"static shaped\" Slice";
  Node* static_shaped_slice;
  TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice(
      g, slice, slice_inputs, cluster_name, &static_shaped_slice));
  ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice);
  return Status::OK();
}

// Returns true if `n` is a slice we can rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input). Fills in
// `slice_inputs` in the process.
bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) {
  if (n->type_string() != "Slice") {
    return false;
  }

  if (!GetXlaClusterForNode(*n).has_value()) {
    // There is no need to change slice ops outside XLA clusters.
    return false;
  }

  if (!GetSliceInputs(n, slice_inputs).ok()) {
    // Could not parse slice inputs. E.g. the sizes input was not a constant.
    return false;
  }

  // If slice_size[i] < -1 for any i then executing the slice will throw an
  // error, and we don't do anything here.
  return absl::c_all_of(slice_inputs->size_as_vector,
                        [](int64 size_i) { return size_i >= -1; });
}

Status FindAndRewriteSlices(Graph* g, bool* changed) {
  std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
  for (Node* n : g->nodes()) {
    SliceInputs slice_inputs;
    if (IsRewritableSlice(n, &slice_inputs)) {
      slices_to_rewrite.push_back({n, std::move(slice_inputs)});
    }
  }

  for (const auto& pair : slices_to_rewrite) {
    TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second,
                                    *GetXlaClusterForNode(*pair.first)));
  }

  if (!slices_to_rewrite.empty()) {
    // We've added constants to the graph; hook them up to _SOURCE.
    FixupSourceAndSinkEdges(g);
  }

  *changed = !slices_to_rewrite.empty();

  return Status::OK();
}
}  // namespace

Status IncreaseDynamismForAutoJitPass::Run(
    const GraphOptimizationPassOptions& options) {
  bool changed;
  TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
  if (changed) {
    legacy_flags::MarkForCompilationPassFlags* flags =
        legacy_flags::GetMarkForCompilationPassFlags();
    if (flags->tf_xla_clustering_debug) {
      dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
                                  **options.graph, options.flib_def);
    }
  }

  return Status::OK();
}

}  // namespace tensorflow
tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h  (new file, 57 lines)

@@ -0,0 +1,57 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
#define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Increases the amount of "dynamism" representable by XLA clusters by rewriting
// the TensorFlow graph. This pass does the following rewrites:
//
// Slice
// -----
//
//   Slice(op, begin, size <must be constant>) =>
//   Slice(op, begin, actual_size(op.shape(), size, begin));
//     _XlaCompileTimeConstantInputs={2}
//
// where
//
//   actual_size(op_shape, size, begin)[i] =
//     size[i] == -1 ? (op_shape[i] - begin[i])
//                   : size[i]
//
// This pass, combined with jit/partially_decluster_pass, reduces the number of
// unnecessary cluster recompilations in some common cases. After the rewrite
// shown above jit/partially_decluster_pass extracts the actual_size(...)
// computation to outside the XLA cluster, causing the cluster to be versioned
// only on the actual size of the XlaDynamicSlice. This avoids recompilation
// due to superficial changes that don't affect tensor shapes.
//
// Future Work TODO(b/111210515)
// -----------------------------
//
// In the future we will also translate StridedSlice and Pad a similar way.
class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_
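To make the actual_size formula concrete, here is a small self-contained sketch of the value the rewritten graph computes at run time for the "size" input of the static-shaped Slice (plain C++, not the pass's graph-building code):

    #include <cstdint>
    #include <vector>

    // A -1 entry means "to the end of the dimension", so it becomes
    // op_shape[i] - begin[i]; non-negative entries pass through unchanged.
    std::vector<int64_t> ActualSize(const std::vector<int64_t>& op_shape,
                                    const std::vector<int64_t>& size,
                                    const std::vector<int64_t>& begin) {
      std::vector<int64_t> result(size.size());
      for (size_t i = 0; i < size.size(); ++i) {
        result[i] = size[i] == -1 ? op_shape[i] - begin[i] : size[i];
      }
      return result;
    }

    // Example: op_shape={10, 500}, size={-1, 500}, begin={3, 0}
    //   => ActualSize(...) == {7, 500}.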
tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc  (new file, 285 lines)

@@ -0,0 +1,285 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

using testing::matchers::AssignedDevice;
using testing::matchers::Attr;
using testing::matchers::Const;
using testing::matchers::CtrlDeps;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;

// A fake device used to populate a DeviceSet.
class FakeDevice : public Device {
 public:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

  Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }

  static std::unique_ptr<Device> Make(const string& name, const string& type) {
    DeviceAttributes device_attributes;
    device_attributes.set_name(name);
    device_attributes.set_device_type(DeviceType(type).type());
    return absl::make_unique<FakeDevice>(device_attributes);
  }
};

const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";

Status IncreaseDynamismForAutoJit(const Scope& s,
                                  std::unique_ptr<Graph>* result) {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU));
  devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU));

  std::unique_ptr<DeviceSet> device_set(new DeviceSet());
  for (auto& device : devices) {
    device_set->AddDevice(device.get());
  }

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(OptimizerOptions::ON_2);
  GraphOptimizationPassOptions options;
  options.graph = &graph;
  options.device_set = device_set.get();
  options.session_options = &session_options;

  // Scope::ToGraph seems to drop assigned devices, probably because it goes
  // through a GraphDef. So explicitly maintain the device assignment.
  std::unordered_map<string, string> assigned_device_names;
  for (Node* n : s.graph()->nodes()) {
    assigned_device_names[n->name()] = n->assigned_device_name();
  }
  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    n->set_assigned_device_name(assigned_device_names[n->name()]);
  }

  IncreaseDynamismForAutoJitPass rewriter;
  TF_RETURN_IF_ERROR(rewriter.Run(options));
  *result = std::move(graph);
  return Status::OK();
}

TEST(SliceToDynamicSliceRewriteTest, Basic) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  const int64 zero_64 = 0;
  const int32 zero_32 = 0;
  const int64 one_64 = 1;

  auto m_input = Out(NodeWith(Op("Placeholder"), Name("input")));
  auto m_begin_s64 = Out(NodeWith(
      Op("Cast"), Inputs(Out(NodeWith(Op("Placeholder"), Name("begin"))))));
  auto m_input_shape = Out(NodeWith(Op("Shape"), Inputs(m_input)));
  auto m_slice_size_0 = Out(NodeWith(
      Op("Sub"), AssignedDevice(kHostName),
      Inputs(
          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                       Inputs(m_input_shape, Const(zero_64), Const(one_64)))),
          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                       Inputs(m_begin_s64, Const(zero_64), Const(one_64)))))));
  auto m_dynamic_slice_size = Out(NodeWith(
      Op("ConcatV2"), AssignedDevice(kHostName),
      Inputs(m_slice_size_0, Const(static_cast<int64>(500)), Const(zero_32))));

  std::vector<int> compile_time_constant_inputs;
  compile_time_constant_inputs.push_back(2);
  auto m_dynamic_slice = NodeWith(
      Op("Slice"), AssignedDevice(kDeviceName),
      Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),
      Inputs(m_input, m_begin_s64, m_dynamic_slice_size));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
}

TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output control_pred = ops::Placeholder(root.WithOpName("control"), DT_BOOL);
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
  root.graph()->AddControlEdge(control_pred.node(), slice.node());

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(static_shaped_slice,
              NodeWith(Op("Slice"),
                       CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
}

TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  auto to_int64 = [](int v) { return static_cast<int64>(v); };

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size =
      ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("Cast")))));
}

TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);

  // The shape refiner throws an error if we use a bogus constant value for
  // size. So we first use a Placeholder to placate the shape refiner, and
  // later replace it with a bogus constant.
  Output size_placeholder =
      ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT32);
  Output slice =
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  Output size = ops::Const(root.WithOpName("size"), {-8, 500});
  TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(),
                                        /*new_src_index=*/0,
                                        /*dst=*/slice.node(), /*dst_index=*/2));

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) {
  Scope root =
      Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName);

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size = ops::Placeholder(root.WithOpName("size"), DT_INT64);
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  auto to_int64 = [](int v) { return static_cast<int64>(v); };

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);

  // The C++ node bindings immediately error out when we try to construct a
  // bogus slice so we first use a placeholder to construct the Slice and then
  // replace the input.
  Output size_placeholder = ops::Placeholder(root.WithOpName("size"), DT_INT64);
  Output slice =
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  Output size =
      ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}});
  TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}
}  // namespace
}  // namespace tensorflow
tensorflow/compiler/jit/jit_compilation_pass_registration.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/build_xla_ops_pass.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"

@@ -44,17 +45,20 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                       MarkForCompilationPass);

+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+                      IncreaseDynamismForAutoJitPass);

 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                       PartiallyDeclusterPass);

 // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
 // also need to run it after the graph has been rewritten to have _Send nodes
 // added for fetches. Before the _Send nodes are added, fetch nodes are
 // identified by name, and encapsulation might remove that node from the graph.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
                       EncapsulateSubgraphsPass);

 // Must run after EncapsulateSubgraphsPass.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                       BuildXlaOpsPass);

 }  // namespace tensorflow
tensorflow/compiler/jit/node_matchers.cc

@@ -204,11 +204,12 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
         }
         return false;
       }
-      if (!AreAttrValuesEqual(it->second, attr_kv_pair.second)) {
+      if (attr_kv_pair.second &&
+          !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
         if (listener->IsInterested()) {
           *listener << "attribute named " << attr_kv_pair.first
                     << " does not match value; expected: \""
-                    << SummarizeAttrValue(attr_kv_pair.second)
+                    << SummarizeAttrValue(*attr_kv_pair.second)
                     << "\", found: \"" << SummarizeAttrValue(it->second)
                     << "\"";
         }

@@ -278,11 +279,13 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
     if (!attrs.empty()) {
       printed_something = true;
       std::vector<string> attrs_str;
-      absl::c_transform(attrs, std::back_inserter(attrs_str),
-                        [](const std::pair<string, AttrValue>& attr_kv_pair) {
-                          return absl::StrCat(
-                              attr_kv_pair.first, "->",
-                              SummarizeAttrValue(attr_kv_pair.second));
+      absl::c_transform(
+          attrs, std::back_inserter(attrs_str),
+          [](const std::pair<string, absl::optional<AttrValue>>& attr_kv_pair) {
+            return absl::StrCat(attr_kv_pair.first, "->",
+                                attr_kv_pair.second
+                                    ? SummarizeAttrValue(*attr_kv_pair.second)
+                                    : "*");
           });
       *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
           << "]";

@@ -327,7 +330,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
   absl::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
   absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
       control_dep_set;
-  std::map<string, AttrValue> attrs;
+  std::map<string, absl::optional<AttrValue>> attrs;
 };

 // Matches a dst and dst_output on an input edge. Today we only use this with

@@ -472,12 +475,28 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper(
   return {bool_attr.first, attr_value};
 }

+std::pair<string, AttrValue> impl::AttrLiteralHelper(
+    const std::pair<string, absl::Span<const int>>& int_list_attr) {
+  AttrValue attr_value;
+  AttrValue::ListValue* list = attr_value.mutable_list();
+  for (int i : int_list_attr.second) {
+    list->add_i(i);
+  }
+  return {int_list_attr.first, attr_value};
+}

 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
   impl::NodeMatcherProperties props;
   props.set_attr(std::move(attr));
   return props;
 }

+impl::NodeMatcherProperties impl::Attr(string name) {
+  impl::NodeMatcherProperties props;
+  props.set_attr({std::move(name), absl::nullopt});
+  return props;
+}

 NodeMatcherProperties ConstantValue(
     const ::tensorflow::Input::Initializer& val) {
   TF_CHECK_OK(val.status);

@@ -486,9 +505,9 @@ NodeMatcherProperties ConstantValue(
   return props;
 }

-::testing::Matcher<const Node*> Const(
+::testing::Matcher<impl::OutEdge> Const(
     const ::tensorflow::Input::Initializer& val) {
-  return NodeWith(ConstantValue(val));
+  return Out(NodeWith(ConstantValue(val)));
 }
 ::testing::Matcher<impl::OutEdge> Out(
     int oidx, ::testing::Matcher<const Node*> node_matcher) {
tensorflow/compiler/jit/node_matchers.h

@@ -84,7 +84,7 @@ class NodeMatcherProperties {
  public:
  using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
  using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>;
-  using AttrKeyValuePair = std::pair<string, AttrValue>;
+  using AttrKeyValuePair = std::pair<string, absl::optional<AttrValue>>;

  const absl::optional<string>& name() const { return name_; }
  const absl::optional<string>& op() const { return op_; }

@@ -163,9 +163,13 @@ impl::NodeMatcherProperties CtrlDeps(
    absl::Span<const ::testing::Matcher<const Node*>> control_deps);

 impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs);
+impl::NodeMatcherProperties Attr(string name);

 std::pair<string, AttrValue> AttrLiteralHelper(
     const std::pair<string, bool>& bool_attr);

+std::pair<string, AttrValue> AttrLiteralHelper(
+    const std::pair<string, absl::Span<const int>>& int_list_attr);
 }  // namespace impl

 // -----------------------------------------------------------------------------

@@ -187,6 +191,10 @@ impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
   return impl::Attr({impl::AttrLiteralHelper({name, value})});
 }

+inline impl::NodeMatcherProperties Attr(const string& name) {
+  return impl::Attr(name);
+}

 // Matches a node with inputs `inputs`.
 //
 // `inputs` are ordered; `inputs`[i] must match input i.

@@ -200,7 +208,8 @@ impl::NodeMatcherProperties Inputs(Ts... inputs) {
     ::testing::Matcher<const Node*> node);

 // Matches the first output of a node that matches `node`.
-::testing::Matcher<impl::OutEdge> Out(::testing::Matcher<const Node*> node) {
+inline ::testing::Matcher<impl::OutEdge> Out(
+    ::testing::Matcher<const Node*> node) {
   return Out(0, node);
 }

@@ -224,7 +233,7 @@ template <typename... Ts>
   return impl::NodeWith(array);
 }

-::testing::Matcher<const Node*> Const(
+::testing::Matcher<impl::OutEdge> Const(
     const ::tensorflow::Input::Initializer& val);
 }  // namespace matchers
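A short sketch of the two new matcher forms (a fragment in the style of the tests above; `node`, `m_input`, and `m_begin` are assumed to exist). Attr(name) with no value matches any node that merely has the attribute, and Const(...) now matches an output edge, so it can appear directly in an Inputs(...) list:

    // Fragment: assumes `node` is a Node* and m_input/m_begin are edge matchers.
    EXPECT_THAT(node, NodeWith(Op("Slice"),
                               Attr(kXlaCompileTimeConstantInputsAttr),
                               Inputs(m_input, m_begin, Const(500))));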
tensorflow/core/graph/node_builder.cc

@@ -104,6 +104,11 @@ NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
   return *this;
 }

+NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
+  def_builder_.Attr("_XlaCluster", xla_cluster);
+  return *this;
+}

 Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
   // In case of error, set *created_node to nullptr.
   if (created_node != nullptr) *created_node = nullptr;
tensorflow/core/graph/node_builder.h

@@ -103,6 +103,9 @@ class NodeBuilder {
   // Sets the device name in the "assigned device" field in tensorflow::Node.
   NodeBuilder& AssignedDevice(StringPiece device);

+  // Sets the _XlaCluster attribute in created node to `xla_cluster`.
+  NodeBuilder& XlaCluster(StringPiece xla_cluster);

   // Set the value of an attr. attr_name must match the name of one of
   // attrs defined by the Op, and value must have the corresponding type
   // (see SetAttrValue() in ../framework/attr_value_util.h for legal
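A minimal sketch of the NodeBuilder-level API (illustrative; assumes an existing `graph` of type Graph* and an `input_node`; the node name, op type, and cluster name are arbitrary examples):

    // Fragment: tag a node with _XlaCluster="cluster_0" while building it.
    Node* node = nullptr;
    TF_CHECK_OK(NodeBuilder("my_identity", "Identity")
                    .Input(input_node)
                    .XlaCluster("cluster_0")
                    .Finalize(graph, &node));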
tensorflow/core/util/device_name_utils.cc

@@ -480,4 +480,16 @@ std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
   }
 }

+/*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
+    const string& device_name, string* host_device_name) {
+  DeviceNameUtils::ParsedName device;
+  if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
+    return errors::Internal("Could not parse device name ", device_name);
+  }
+  device.type = "CPU";
+  device.id = 0;
+  *host_device_name = DeviceNameUtils::ParsedNameToString(device);
+  return Status::OK();
+}

 }  // namespace tensorflow

tensorflow/core/util/device_name_utils.h

@@ -169,6 +169,11 @@ class DeviceNameUtils {
   // mapping.
   static std::vector<string> GetLocalNamesForDeviceMappings(
       const ParsedName& pn);

+  // Returns name of the CPU:0 device on the same host as the device
+  // `device_name`.
+  static Status DeviceNameToCpuDeviceName(const string& device_name,
+                                          string* host_device_name);
 };

 }  // namespace tensorflow
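A usage sketch of the new helper (the device names mirror the test constants above):

    // Fragment: map a GPU device name to CPU:0 on the same host.
    string host_device;
    TF_CHECK_OK(DeviceNameUtils::DeviceNameToCpuDeviceName(
        "/job:worker/replica:0/task:0/device:GPU:0", &host_device));
    // host_device == "/job:worker/replica:0/task:0/device:CPU:0"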