[XLA] Turn constant dynamic slices into slices.
PiperOrigin-RevId: 308444591 Change-Id: I5ef204447fd6e689920cb18e56e1e7dbae014548
This commit is contained in:
parent
339be9d20e
commit
beef8eadbc
@ -449,6 +449,7 @@ cc_library(
|
|||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/index_util.h"
|
#include "tensorflow/compiler/xla/index_util.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
@ -198,6 +199,34 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
|||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
||||||
|
switch (shape().element_type()) {
|
||||||
|
case U8:
|
||||||
|
return GetFirstElement<uint8>();
|
||||||
|
case U16:
|
||||||
|
return GetFirstElement<uint16>();
|
||||||
|
case U32:
|
||||||
|
return GetFirstElement<uint32>();
|
||||||
|
case U64: {
|
||||||
|
int64 v = GetFirstElement<uint64>();
|
||||||
|
if (v < 0) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
case S8:
|
||||||
|
return GetFirstElement<int8>();
|
||||||
|
case S16:
|
||||||
|
return GetFirstElement<int16>();
|
||||||
|
case S32:
|
||||||
|
return GetFirstElement<int32>();
|
||||||
|
case S64:
|
||||||
|
return GetFirstElement<int64>();
|
||||||
|
default:
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
Status MutableLiteralBase::CopySliceFromInternal(
|
Status MutableLiteralBase::CopySliceFromInternal(
|
||||||
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/array2d.h"
|
#include "tensorflow/compiler/xla/array2d.h"
|
||||||
#include "tensorflow/compiler/xla/array3d.h"
|
#include "tensorflow/compiler/xla/array3d.h"
|
||||||
@ -116,6 +117,9 @@ class LiteralBase {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
NativeT GetFirstElement() const;
|
NativeT GetFirstElement() const;
|
||||||
|
|
||||||
|
// As above but returns any integer type casted to an int64.
|
||||||
|
absl::optional<int64> GetFirstInteger() const;
|
||||||
|
|
||||||
// As Get(), but determines the correct type and converts the value
|
// As Get(), but determines the correct type and converts the value
|
||||||
// into text.
|
// into text.
|
||||||
string GetAsString(absl::Span<const int64> multi_index,
|
string GetAsString(absl::Span<const int64> multi_index,
|
||||||
|
76
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Normal file → Executable file
76
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Normal file → Executable file
@ -816,6 +816,8 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
|||||||
// Concatenate the indices and updates
|
// Concatenate the indices and updates
|
||||||
if (index_concat_is_safe && same_dimension_numbers &&
|
if (index_concat_is_safe && same_dimension_numbers &&
|
||||||
index_concat_dimension &&
|
index_concat_dimension &&
|
||||||
|
lhs_scatter_index->shape().element_type() ==
|
||||||
|
rhs_scatter_index->shape().element_type() &&
|
||||||
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
|
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
|
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
|
||||||
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
|
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
|
||||||
@ -3636,6 +3638,39 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
|||||||
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
||||||
dynamic_slice->shape()));
|
dynamic_slice->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert a dynamic slice into a slice if all offsets are constant and the
|
||||||
|
// operand is not constant. If ev
|
||||||
|
if (operand->opcode() != HloOpcode::kConstant &&
|
||||||
|
absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
|
||||||
|
dynamic_slice->operands().end()),
|
||||||
|
[](HloInstruction* operand) {
|
||||||
|
return operand->opcode() == HloOpcode::kConstant &&
|
||||||
|
ShapeUtil::ElementIsIntegral(operand->shape());
|
||||||
|
})) {
|
||||||
|
const int64 rank = operand->shape().rank();
|
||||||
|
std::vector<int64> slice_starts(rank);
|
||||||
|
std::vector<int64> slice_limits(rank);
|
||||||
|
std::vector<int64> slice_strides(rank, 1);
|
||||||
|
|
||||||
|
for (int64 i = 0; i < rank; ++i) {
|
||||||
|
absl::optional<int64> offset =
|
||||||
|
dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
|
||||||
|
if (!offset || *offset < 0) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
const int64 max_offset =
|
||||||
|
dynamic_slice->operand(0)->shape().dimensions(i) -
|
||||||
|
dynamic_slice->shape().dimensions(i);
|
||||||
|
slice_starts[i] = std::min(max_offset, *offset);
|
||||||
|
slice_limits[i] =
|
||||||
|
std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
|
||||||
|
}
|
||||||
|
return ReplaceWithNewInstruction(
|
||||||
|
dynamic_slice,
|
||||||
|
HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
|
||||||
|
slice_starts, slice_limits, slice_strides));
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3670,8 +3705,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
|||||||
compatible = false;
|
compatible = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
PaddingConfig padding_config;
|
||||||
if (compatible) {
|
if (compatible) {
|
||||||
PaddingConfig padding_config;
|
|
||||||
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
|
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
|
||||||
auto padding_config_dim = padding_config.add_dimensions();
|
auto padding_config_dim = padding_config.add_dimensions();
|
||||||
auto slice_dim_start = update_start_indx->operand(dim + offset);
|
auto slice_dim_start = update_start_indx->operand(dim + offset);
|
||||||
@ -3680,37 +3715,32 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
||||||
int64 beg;
|
absl::optional<int64> beg =
|
||||||
if (slice_dim_start->shape().element_type() == S32) {
|
slice_dim_start->literal().GetFirstInteger();
|
||||||
beg = slice_dim_start->literal().Get<int32>({});
|
if (!beg) {
|
||||||
} else if (slice_dim_start->shape().element_type() == U32) {
|
|
||||||
beg = slice_dim_start->literal().Get<uint32>({});
|
|
||||||
} else {
|
|
||||||
compatible = false;
|
compatible = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
VLOG(2) << "beg value:" << beg;
|
VLOG(2) << "beg value:" << *beg;
|
||||||
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
||||||
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
||||||
padding_config_dim->set_edge_padding_low(beg);
|
padding_config_dim->set_edge_padding_low(*beg);
|
||||||
padding_config_dim->set_edge_padding_high(
|
padding_config_dim->set_edge_padding_high(
|
||||||
std::max(bcast_width - (beg + update_width), int64{0}));
|
std::max(bcast_width - (*beg + update_width), int64{0}));
|
||||||
// dynamic_update_slice does not specify a stride
|
// dynamic_update_slice does not specify a stride
|
||||||
padding_config_dim->set_interior_padding(0);
|
padding_config_dim->set_interior_padding(0);
|
||||||
}
|
}
|
||||||
if (compatible) {
|
}
|
||||||
HloInstruction* pad =
|
|
||||||
computation_->AddInstruction(HloInstruction::CreatePad(
|
if (compatible) {
|
||||||
updated_shape, dus_update, pad_value, padding_config));
|
HloInstruction* pad =
|
||||||
VLOG(2) << dynamic_update_slice->ToString();
|
computation_->AddInstruction(HloInstruction::CreatePad(
|
||||||
VLOG(2) << " with pad:" << pad->ToString();
|
updated_shape, dus_update, pad_value, padding_config));
|
||||||
VLOG(2) << " Computation before rewrite is: "
|
VLOG(2) << dynamic_update_slice->ToString();
|
||||||
<< dynamic_update_slice->parent()->ToString();
|
VLOG(2) << " with pad:" << pad->ToString();
|
||||||
auto res = ReplaceInstruction(dynamic_update_slice, pad);
|
VLOG(2) << " Computation before rewrite is: "
|
||||||
VLOG(2) << " Computation after rewrite is: "
|
<< dynamic_update_slice->parent()->ToString();
|
||||||
<< pad->parent()->ToString();
|
return ReplaceInstruction(dynamic_update_slice, pad);
|
||||||
return res;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4183,6 +4183,31 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
|
|||||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
|
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
|
||||||
|
auto m = CreateNewVerifiedModule();
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
|
||||||
|
std::vector<HloInstruction*> params;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
params.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::CreateR0<int32>(2 << (i + 1)))));
|
||||||
|
}
|
||||||
|
Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
|
||||||
|
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||||
|
ds_shape,
|
||||||
|
builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, shape, "operand")),
|
||||||
|
params,
|
||||||
|
/*slice_sizes=*/{2, 20, 200}));
|
||||||
|
|
||||||
|
auto computation = m->AddEntryComputation(builder.Build());
|
||||||
|
AlgebraicSimplifier simplifier(default_options_);
|
||||||
|
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
GmockMatch(m::Slice(m::Parameter())));
|
||||||
|
}
|
||||||
|
|
||||||
// A dynamic-update-slice is trivial if its start indices are all zeroes and the
|
// A dynamic-update-slice is trivial if its start indices are all zeroes and the
|
||||||
// size of its "update" equals the size of its output. In this case, the
|
// size of its "update" equals the size of its output. In this case, the
|
||||||
// dynamic-update-slice is equal to its update.
|
// dynamic-update-slice is equal to its update.
|
||||||
|
Loading…
Reference in New Issue
Block a user