Don't rewrite slices that have a constant begin input

PiperOrigin-RevId: 232746272
This commit is contained in:
Sanjoy Das 2019-02-06 14:24:03 -08:00 committed by TensorFlower Gardener
parent c855674062
commit 0064ad1e1c
2 changed files with 33 additions and 5 deletions

View File

@ -247,6 +247,7 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
.NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
Scope host_scope = main_scope.WithAssignedDevice(host_name);
// In the future we may want to be clever here and avoid the extra Cast ops.
SliceInputs slice_inputs_int64 =
MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);
@ -312,9 +313,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
return Status::OK();
}
// Return true if `n` is a slice we can rewrite to have a static shape
// Return true if `n` is a slice we should rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input).
xla::StatusOr<bool> IsRewritableSlice(Node* n) {
xla::StatusOr<bool> ShouldRewriteSlice(Node* n) {
if (n->type_string() != "Slice") {
return false;
}
@ -332,14 +333,20 @@ xla::StatusOr<bool> IsRewritableSlice(Node* n) {
// If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here.
return absl::c_all_of(slice_inputs->size_as_vector,
[](int64 size_i) { return size_i >= -1; });
bool slice_size_has_error = absl::c_all_of(
slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; });
if (!slice_size_has_error) {
return false;
}
// No point in rewriting slices that have both size and begin as constants.
return !slice_inputs->begin.node()->IsConstant();
}
Status FindAndRewriteSlices(Graph* g, bool* changed) {
std::vector<Node*> slices_to_rewrite;
for (Node* n : g->nodes()) {
TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n));
if (is_rewritable) {
slices_to_rewrite.push_back(n);
}

View File

@ -432,5 +432,26 @@ TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) {
Name("dependency")))));
}
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Const(root.WithOpName("begin"), {10, 10});
Output size = ops::Const(root.WithOpName("size"), {-1, 500});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* slice_node = testing::FindNodeByName(result.get(), "slice");
EXPECT_THAT(slice_node,
NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))),
Out(NodeWith(Op("Const"))),
Out(NodeWith(Op("Const"))))));
}
} // namespace
} // namespace tensorflow