[XLA] Turn constant dynamic slices into slices.
PiperOrigin-RevId: 308358674 Change-Id: I92f1674325c4824f858b1183455135e97e44bebc
This commit is contained in:
parent
47a28473e4
commit
490288d631
@ -449,7 +449,6 @@ cc_library(
|
|||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/types:optional",
|
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/optional.h"
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/index_util.h"
|
#include "tensorflow/compiler/xla/index_util.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
@ -199,34 +198,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
|||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
|
||||||
switch (shape().element_type()) {
|
|
||||||
case U8:
|
|
||||||
return GetFirstElement<uint8>();
|
|
||||||
case U16:
|
|
||||||
return GetFirstElement<uint16>();
|
|
||||||
case U32:
|
|
||||||
return GetFirstElement<uint32>();
|
|
||||||
case U64: {
|
|
||||||
int64 v = GetFirstElement<uint64>();
|
|
||||||
if (v < 0) {
|
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
case S8:
|
|
||||||
return GetFirstElement<int8>();
|
|
||||||
case S16:
|
|
||||||
return GetFirstElement<int16>();
|
|
||||||
case S32:
|
|
||||||
return GetFirstElement<int32>();
|
|
||||||
case S64:
|
|
||||||
return GetFirstElement<int64>();
|
|
||||||
default:
|
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
Status MutableLiteralBase::CopySliceFromInternal(
|
Status MutableLiteralBase::CopySliceFromInternal(
|
||||||
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/array2d.h"
|
#include "tensorflow/compiler/xla/array2d.h"
|
||||||
#include "tensorflow/compiler/xla/array3d.h"
|
#include "tensorflow/compiler/xla/array3d.h"
|
||||||
@ -117,9 +116,6 @@ class LiteralBase {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
NativeT GetFirstElement() const;
|
NativeT GetFirstElement() const;
|
||||||
|
|
||||||
// As above but returns any integer type casted to an int64.
|
|
||||||
absl::optional<int64> GetFirstInteger() const;
|
|
||||||
|
|
||||||
// As Get(), but determines the correct type and converts the value
|
// As Get(), but determines the correct type and converts the value
|
||||||
// into text.
|
// into text.
|
||||||
string GetAsString(absl::Span<const int64> multi_index,
|
string GetAsString(absl::Span<const int64> multi_index,
|
||||||
|
60
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Executable file → Normal file
60
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Executable file → Normal file
@ -816,8 +816,6 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
|||||||
// Concatenate the indices and updates
|
// Concatenate the indices and updates
|
||||||
if (index_concat_is_safe && same_dimension_numbers &&
|
if (index_concat_is_safe && same_dimension_numbers &&
|
||||||
index_concat_dimension &&
|
index_concat_dimension &&
|
||||||
lhs_scatter_index->shape().element_type() ==
|
|
||||||
rhs_scatter_index->shape().element_type() &&
|
|
||||||
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
|
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
|
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
|
||||||
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
|
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
|
||||||
@ -3638,39 +3636,6 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
|||||||
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
||||||
dynamic_slice->shape()));
|
dynamic_slice->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a dynamic slice into a slice if all offsets are constant and the
|
|
||||||
// operand is not constant. If ev
|
|
||||||
if (operand->opcode() != HloOpcode::kConstant &&
|
|
||||||
absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
|
|
||||||
dynamic_slice->operands().end()),
|
|
||||||
[](HloInstruction* operand) {
|
|
||||||
return operand->opcode() == HloOpcode::kConstant &&
|
|
||||||
ShapeUtil::ElementIsIntegral(operand->shape());
|
|
||||||
})) {
|
|
||||||
const int64 rank = operand->shape().rank();
|
|
||||||
std::vector<int64> slice_starts(rank);
|
|
||||||
std::vector<int64> slice_limits(rank);
|
|
||||||
std::vector<int64> slice_strides(rank, 1);
|
|
||||||
|
|
||||||
for (int64 i = 0; i < rank; ++i) {
|
|
||||||
absl::optional<int64> offset =
|
|
||||||
dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
|
|
||||||
if (!offset || *offset < 0) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
const int64 max_offset =
|
|
||||||
dynamic_slice->operand(0)->shape().dimensions(i) -
|
|
||||||
dynamic_slice->shape().dimensions(i);
|
|
||||||
slice_starts[i] = std::min(max_offset, *offset);
|
|
||||||
slice_limits[i] =
|
|
||||||
std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
|
|
||||||
}
|
|
||||||
return ReplaceWithNewInstruction(
|
|
||||||
dynamic_slice,
|
|
||||||
HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
|
|
||||||
slice_starts, slice_limits, slice_strides));
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3705,8 +3670,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
|||||||
compatible = false;
|
compatible = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PaddingConfig padding_config;
|
|
||||||
if (compatible) {
|
if (compatible) {
|
||||||
|
PaddingConfig padding_config;
|
||||||
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
|
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
|
||||||
auto padding_config_dim = padding_config.add_dimensions();
|
auto padding_config_dim = padding_config.add_dimensions();
|
||||||
auto slice_dim_start = update_start_indx->operand(dim + offset);
|
auto slice_dim_start = update_start_indx->operand(dim + offset);
|
||||||
@ -3715,23 +3680,24 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
||||||
absl::optional<int64> beg =
|
int64 beg;
|
||||||
slice_dim_start->literal().GetFirstInteger();
|
if (slice_dim_start->shape().element_type() == S32) {
|
||||||
if (!beg) {
|
beg = slice_dim_start->literal().Get<int32>({});
|
||||||
|
} else if (slice_dim_start->shape().element_type() == U32) {
|
||||||
|
beg = slice_dim_start->literal().Get<uint32>({});
|
||||||
|
} else {
|
||||||
compatible = false;
|
compatible = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
VLOG(2) << "beg value:" << *beg;
|
VLOG(2) << "beg value:" << beg;
|
||||||
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
||||||
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
||||||
padding_config_dim->set_edge_padding_low(*beg);
|
padding_config_dim->set_edge_padding_low(beg);
|
||||||
padding_config_dim->set_edge_padding_high(
|
padding_config_dim->set_edge_padding_high(
|
||||||
std::max(bcast_width - (*beg + update_width), int64{0}));
|
std::max(bcast_width - (beg + update_width), int64{0}));
|
||||||
// dynamic_update_slice does not specify a stride
|
// dynamic_update_slice does not specify a stride
|
||||||
padding_config_dim->set_interior_padding(0);
|
padding_config_dim->set_interior_padding(0);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (compatible) {
|
if (compatible) {
|
||||||
HloInstruction* pad =
|
HloInstruction* pad =
|
||||||
computation_->AddInstruction(HloInstruction::CreatePad(
|
computation_->AddInstruction(HloInstruction::CreatePad(
|
||||||
@ -3740,7 +3706,11 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
|||||||
VLOG(2) << " with pad:" << pad->ToString();
|
VLOG(2) << " with pad:" << pad->ToString();
|
||||||
VLOG(2) << " Computation before rewrite is: "
|
VLOG(2) << " Computation before rewrite is: "
|
||||||
<< dynamic_update_slice->parent()->ToString();
|
<< dynamic_update_slice->parent()->ToString();
|
||||||
return ReplaceInstruction(dynamic_update_slice, pad);
|
auto res = ReplaceInstruction(dynamic_update_slice, pad);
|
||||||
|
VLOG(2) << " Computation after rewrite is: "
|
||||||
|
<< pad->parent()->ToString();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4183,31 +4183,6 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
|
|||||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
|
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
|
|
||||||
auto m = CreateNewVerifiedModule();
|
|
||||||
HloComputation::Builder builder(TestName());
|
|
||||||
|
|
||||||
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
|
|
||||||
std::vector<HloInstruction*> params;
|
|
||||||
for (int i = 0; i < 3; ++i) {
|
|
||||||
params.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
|
|
||||||
LiteralUtil::CreateR0<int32>(2 << (i + 1)))));
|
|
||||||
}
|
|
||||||
Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
|
|
||||||
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
|
|
||||||
ds_shape,
|
|
||||||
builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, shape, "operand")),
|
|
||||||
params,
|
|
||||||
/*slice_sizes=*/{2, 20, 200}));
|
|
||||||
|
|
||||||
auto computation = m->AddEntryComputation(builder.Build());
|
|
||||||
AlgebraicSimplifier simplifier(default_options_);
|
|
||||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
|
||||||
EXPECT_THAT(computation->root_instruction(),
|
|
||||||
GmockMatch(m::Slice(m::Parameter())));
|
|
||||||
}
|
|
||||||
|
|
||||||
// A dynamic-update-slice is trivial if its start indices are all zeroes and the
|
// A dynamic-update-slice is trivial if its start indices are all zeroes and the
|
||||||
// size of its "update" equals the size of its output. In this case, the
|
// size of its "update" equals the size of its output. In this case, the
|
||||||
// dynamic-update-slice is equal to its update.
|
// dynamic-update-slice is equal to its update.
|
||||||
|
Loading…
Reference in New Issue
Block a user