[XLA] Turn constant dynamic slices into slices.

PiperOrigin-RevId: 308358674
Change-Id: I92f1674325c4824f858b1183455135e97e44bebc
This commit is contained in:
A. Unique TensorFlower 2020-04-24 17:39:34 -07:00 committed by TensorFlower Gardener
parent 47a28473e4
commit 490288d631
5 changed files with 23 additions and 112 deletions

View File

@ -449,7 +449,6 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/primitive_util.h"
@ -199,34 +198,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
return literal; return literal;
} }
absl::optional<int64> LiteralBase::GetFirstInteger() const {
switch (shape().element_type()) {
case U8:
return GetFirstElement<uint8>();
case U16:
return GetFirstElement<uint16>();
case U32:
return GetFirstElement<uint32>();
case U64: {
int64 v = GetFirstElement<uint64>();
if (v < 0) {
return absl::nullopt;
}
return v;
}
case S8:
return GetFirstElement<int8>();
case S16:
return GetFirstElement<int16>();
case S32:
return GetFirstElement<int32>();
case S64:
return GetFirstElement<int64>();
default:
return absl::nullopt;
}
}
template <typename NativeT> template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal( Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, absl::Span<const int64> src_base, const LiteralBase& src_literal, absl::Span<const int64> src_base,

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array3d.h"
@ -117,9 +116,6 @@ class LiteralBase {
template <typename NativeT> template <typename NativeT>
NativeT GetFirstElement() const; NativeT GetFirstElement() const;
// As above but returns any integer type casted to an int64.
absl::optional<int64> GetFirstInteger() const;
// As Get(), but determines the correct type and converts the value // As Get(), but determines the correct type and converts the value
// into text. // into text.
string GetAsString(absl::Span<const int64> multi_index, string GetAsString(absl::Span<const int64> multi_index,

76
tensorflow/compiler/xla/service/algebraic_simplifier.cc Executable file → Normal file
View File

@ -816,8 +816,6 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
// Concatenate the indices and updates // Concatenate the indices and updates
if (index_concat_is_safe && same_dimension_numbers && if (index_concat_is_safe && same_dimension_numbers &&
index_concat_dimension && index_concat_dimension &&
lhs_scatter_index->shape().element_type() ==
rhs_scatter_index->shape().element_type() &&
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
@ -3638,39 +3636,6 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(), MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
dynamic_slice->shape())); dynamic_slice->shape()));
} }
// Convert a dynamic slice into a slice if all offsets are constant and the
// operand is not constant. If ev
if (operand->opcode() != HloOpcode::kConstant &&
absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
dynamic_slice->operands().end()),
[](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConstant &&
ShapeUtil::ElementIsIntegral(operand->shape());
})) {
const int64 rank = operand->shape().rank();
std::vector<int64> slice_starts(rank);
std::vector<int64> slice_limits(rank);
std::vector<int64> slice_strides(rank, 1);
for (int64 i = 0; i < rank; ++i) {
absl::optional<int64> offset =
dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
if (!offset || *offset < 0) {
return Status::OK();
}
const int64 max_offset =
dynamic_slice->operand(0)->shape().dimensions(i) -
dynamic_slice->shape().dimensions(i);
slice_starts[i] = std::min(max_offset, *offset);
slice_limits[i] =
std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
}
return ReplaceWithNewInstruction(
dynamic_slice,
HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
slice_starts, slice_limits, slice_strides));
}
return Status::OK(); return Status::OK();
} }
@ -3705,8 +3670,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
compatible = false; compatible = false;
} }
} }
PaddingConfig padding_config;
if (compatible) { if (compatible) {
PaddingConfig padding_config;
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) { for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
auto padding_config_dim = padding_config.add_dimensions(); auto padding_config_dim = padding_config.add_dimensions();
auto slice_dim_start = update_start_indx->operand(dim + offset); auto slice_dim_start = update_start_indx->operand(dim + offset);
@ -3715,32 +3680,37 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
break; break;
} }
VLOG(2) << "slice :" << slice_dim_start->ToString(); VLOG(2) << "slice :" << slice_dim_start->ToString();
absl::optional<int64> beg = int64 beg;
slice_dim_start->literal().GetFirstInteger(); if (slice_dim_start->shape().element_type() == S32) {
if (!beg) { beg = slice_dim_start->literal().Get<int32>({});
} else if (slice_dim_start->shape().element_type() == U32) {
beg = slice_dim_start->literal().Get<uint32>({});
} else {
compatible = false; compatible = false;
break; break;
} }
VLOG(2) << "beg value:" << *beg; VLOG(2) << "beg value:" << beg;
auto update_width = ShapeUtil::GetDimension(update_shape, dim); auto update_width = ShapeUtil::GetDimension(update_shape, dim);
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
padding_config_dim->set_edge_padding_low(*beg); padding_config_dim->set_edge_padding_low(beg);
padding_config_dim->set_edge_padding_high( padding_config_dim->set_edge_padding_high(
std::max(bcast_width - (*beg + update_width), int64{0})); std::max(bcast_width - (beg + update_width), int64{0}));
// dynamic_update_slice does not specify a stride // dynamic_update_slice does not specify a stride
padding_config_dim->set_interior_padding(0); padding_config_dim->set_interior_padding(0);
} }
} if (compatible) {
HloInstruction* pad =
if (compatible) { computation_->AddInstruction(HloInstruction::CreatePad(
HloInstruction* pad = updated_shape, dus_update, pad_value, padding_config));
computation_->AddInstruction(HloInstruction::CreatePad( VLOG(2) << dynamic_update_slice->ToString();
updated_shape, dus_update, pad_value, padding_config)); VLOG(2) << " with pad:" << pad->ToString();
VLOG(2) << dynamic_update_slice->ToString(); VLOG(2) << " Computation before rewrite is: "
VLOG(2) << " with pad:" << pad->ToString(); << dynamic_update_slice->parent()->ToString();
VLOG(2) << " Computation before rewrite is: " auto res = ReplaceInstruction(dynamic_update_slice, pad);
<< dynamic_update_slice->parent()->ToString(); VLOG(2) << " Computation after rewrite is: "
return ReplaceInstruction(dynamic_update_slice, pad); << pad->parent()->ToString();
return res;
}
} }
} }

View File

@ -4183,31 +4183,6 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
} }
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
auto m = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
std::vector<HloInstruction*> params;
for (int i = 0; i < 3; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(2 << (i + 1)))));
}
Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ds_shape,
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "operand")),
params,
/*slice_sizes=*/{2, 20, 200}));
auto computation = m->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Slice(m::Parameter())));
}
// A dynamic-update-slice is trivial if its start indices are all zeroes and the // A dynamic-update-slice is trivial if its start indices are all zeroes and the
// size of its "update" equals the size of its output. In this case, the // size of its "update" equals the size of its output. In this case, the
// dynamic-update-slice is equal to its update. // dynamic-update-slice is equal to its update.