[XLA] Turn constant dynamic slices into slices.
PiperOrigin-RevId: 308358674 Change-Id: I92f1674325c4824f858b1183455135e97e44bebc
This commit is contained in:
parent
47a28473e4
commit
490288d631
@ -449,7 +449,6 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
@ -199,34 +198,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
||||
return literal;
|
||||
}
|
||||
|
||||
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
||||
switch (shape().element_type()) {
|
||||
case U8:
|
||||
return GetFirstElement<uint8>();
|
||||
case U16:
|
||||
return GetFirstElement<uint16>();
|
||||
case U32:
|
||||
return GetFirstElement<uint32>();
|
||||
case U64: {
|
||||
int64 v = GetFirstElement<uint64>();
|
||||
if (v < 0) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
case S8:
|
||||
return GetFirstElement<int8>();
|
||||
case S16:
|
||||
return GetFirstElement<int16>();
|
||||
case S32:
|
||||
return GetFirstElement<int32>();
|
||||
case S64:
|
||||
return GetFirstElement<int64>();
|
||||
default:
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
Status MutableLiteralBase::CopySliceFromInternal(
|
||||
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/array3d.h"
|
||||
@ -117,9 +116,6 @@ class LiteralBase {
|
||||
template <typename NativeT>
|
||||
NativeT GetFirstElement() const;
|
||||
|
||||
// As above but returns any integer type casted to an int64.
|
||||
absl::optional<int64> GetFirstInteger() const;
|
||||
|
||||
// As Get(), but determines the correct type and converts the value
|
||||
// into text.
|
||||
string GetAsString(absl::Span<const int64> multi_index,
|
||||
|
60
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Executable file → Normal file
60
tensorflow/compiler/xla/service/algebraic_simplifier.cc
Executable file → Normal file
@ -816,8 +816,6 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
// Concatenate the indices and updates
|
||||
if (index_concat_is_safe && same_dimension_numbers &&
|
||||
index_concat_dimension &&
|
||||
lhs_scatter_index->shape().element_type() ==
|
||||
rhs_scatter_index->shape().element_type() &&
|
||||
ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
|
||||
MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
|
||||
@ -3638,39 +3636,6 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
||||
MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(),
|
||||
dynamic_slice->shape()));
|
||||
}
|
||||
|
||||
// Convert a dynamic slice into a slice if all offsets are constant and the
|
||||
// operand is not constant. If ev
|
||||
if (operand->opcode() != HloOpcode::kConstant &&
|
||||
absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
|
||||
dynamic_slice->operands().end()),
|
||||
[](HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kConstant &&
|
||||
ShapeUtil::ElementIsIntegral(operand->shape());
|
||||
})) {
|
||||
const int64 rank = operand->shape().rank();
|
||||
std::vector<int64> slice_starts(rank);
|
||||
std::vector<int64> slice_limits(rank);
|
||||
std::vector<int64> slice_strides(rank, 1);
|
||||
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
absl::optional<int64> offset =
|
||||
dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
|
||||
if (!offset || *offset < 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
const int64 max_offset =
|
||||
dynamic_slice->operand(0)->shape().dimensions(i) -
|
||||
dynamic_slice->shape().dimensions(i);
|
||||
slice_starts[i] = std::min(max_offset, *offset);
|
||||
slice_limits[i] =
|
||||
std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
|
||||
}
|
||||
return ReplaceWithNewInstruction(
|
||||
dynamic_slice,
|
||||
HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
|
||||
slice_starts, slice_limits, slice_strides));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -3705,8 +3670,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
compatible = false;
|
||||
}
|
||||
}
|
||||
PaddingConfig padding_config;
|
||||
if (compatible) {
|
||||
PaddingConfig padding_config;
|
||||
for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
|
||||
auto padding_config_dim = padding_config.add_dimensions();
|
||||
auto slice_dim_start = update_start_indx->operand(dim + offset);
|
||||
@ -3715,23 +3680,24 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
break;
|
||||
}
|
||||
VLOG(2) << "slice :" << slice_dim_start->ToString();
|
||||
absl::optional<int64> beg =
|
||||
slice_dim_start->literal().GetFirstInteger();
|
||||
if (!beg) {
|
||||
int64 beg;
|
||||
if (slice_dim_start->shape().element_type() == S32) {
|
||||
beg = slice_dim_start->literal().Get<int32>({});
|
||||
} else if (slice_dim_start->shape().element_type() == U32) {
|
||||
beg = slice_dim_start->literal().Get<uint32>({});
|
||||
} else {
|
||||
compatible = false;
|
||||
break;
|
||||
}
|
||||
VLOG(2) << "beg value:" << *beg;
|
||||
VLOG(2) << "beg value:" << beg;
|
||||
auto update_width = ShapeUtil::GetDimension(update_shape, dim);
|
||||
auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
|
||||
padding_config_dim->set_edge_padding_low(*beg);
|
||||
padding_config_dim->set_edge_padding_low(beg);
|
||||
padding_config_dim->set_edge_padding_high(
|
||||
std::max(bcast_width - (*beg + update_width), int64{0}));
|
||||
std::max(bcast_width - (beg + update_width), int64{0}));
|
||||
// dynamic_update_slice does not specify a stride
|
||||
padding_config_dim->set_interior_padding(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (compatible) {
|
||||
HloInstruction* pad =
|
||||
computation_->AddInstruction(HloInstruction::CreatePad(
|
||||
@ -3740,7 +3706,11 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
VLOG(2) << " with pad:" << pad->ToString();
|
||||
VLOG(2) << " Computation before rewrite is: "
|
||||
<< dynamic_update_slice->parent()->ToString();
|
||||
return ReplaceInstruction(dynamic_update_slice, pad);
|
||||
auto res = ReplaceInstruction(dynamic_update_slice, pad);
|
||||
VLOG(2) << " Computation after rewrite is: "
|
||||
<< pad->parent()->ToString();
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4183,31 +4183,6 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
|
||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
|
||||
std::vector<HloInstruction*> params;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
params.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0<int32>(2 << (i + 1)))));
|
||||
}
|
||||
Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
ds_shape,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "operand")),
|
||||
params,
|
||||
/*slice_sizes=*/{2, 20, 200}));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Slice(m::Parameter())));
|
||||
}
|
||||
|
||||
// A dynamic-update-slice is trivial if its start indices are all zeroes and the
|
||||
// size of its "update" equals the size of its output. In this case, the
|
||||
// dynamic-update-slice is equal to its update.
|
||||
|
Loading…
Reference in New Issue
Block a user