Don't rewrite slices that have a constant begin input
PiperOrigin-RevId: 232746272
This commit is contained in:
parent
c855674062
commit
0064ad1e1c
tensorflow/compiler/jit
@ -247,6 +247,7 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
|
||||
.NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
|
||||
Scope host_scope = main_scope.WithAssignedDevice(host_name);
|
||||
|
||||
// In the future we may want to be clever here and avoid the extra Cast ops.
|
||||
SliceInputs slice_inputs_int64 =
|
||||
MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);
|
||||
|
||||
@ -312,9 +313,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Return true if `n` is a slice we can rewrite to have a static shape
|
||||
// Return true if `n` is a slice we should rewrite to have a static shape
|
||||
// (i.e. have the output shape only depend on the "size" input).
|
||||
xla::StatusOr<bool> IsRewritableSlice(Node* n) {
|
||||
xla::StatusOr<bool> ShouldRewriteSlice(Node* n) {
|
||||
if (n->type_string() != "Slice") {
|
||||
return false;
|
||||
}
|
||||
@ -332,14 +333,20 @@ xla::StatusOr<bool> IsRewritableSlice(Node* n) {
|
||||
|
||||
// If slice_size[i] < -1 for any i then executing the slice will throw an
|
||||
// error, and we don't do anything here.
|
||||
return absl::c_all_of(slice_inputs->size_as_vector,
|
||||
[](int64 size_i) { return size_i >= -1; });
|
||||
bool slice_size_has_error = absl::c_all_of(
|
||||
slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; });
|
||||
if (!slice_size_has_error) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// No point in rewriting slices that have both size and begin as constants.
|
||||
return !slice_inputs->begin.node()->IsConstant();
|
||||
}
|
||||
|
||||
Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
||||
std::vector<Node*> slices_to_rewrite;
|
||||
for (Node* n : g->nodes()) {
|
||||
TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
|
||||
TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n));
|
||||
if (is_rewritable) {
|
||||
slices_to_rewrite.push_back(n);
|
||||
}
|
||||
|
@ -432,5 +432,26 @@ TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) {
|
||||
Name("dependency")))));
|
||||
}
|
||||
|
||||
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) {
|
||||
Scope root = Scope::NewRootScope()
|
||||
.ExitOnError()
|
||||
.WithAssignedDevice(kDeviceName)
|
||||
.WithXlaCluster("cluster_0");
|
||||
|
||||
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
|
||||
Output begin = ops::Const(root.WithOpName("begin"), {10, 10});
|
||||
Output size = ops::Const(root.WithOpName("size"), {-1, 500});
|
||||
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
|
||||
|
||||
std::unique_ptr<Graph> result;
|
||||
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
|
||||
|
||||
Node* slice_node = testing::FindNodeByName(result.get(), "slice");
|
||||
EXPECT_THAT(slice_node,
|
||||
NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))),
|
||||
Out(NodeWith(Op("Const"))),
|
||||
Out(NodeWith(Op("Const"))))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user