diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index a2993058321..1350f9e3e0b 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -449,6 +449,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 44e6a3c7bdb..cbbad741ce3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -198,6 +199,34 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { return literal; } +absl::optional LiteralBase::GetFirstInteger() const { + switch (shape().element_type()) { + case U8: + return GetFirstElement(); + case U16: + return GetFirstElement(); + case U32: + return GetFirstElement(); + case U64: { + int64 v = GetFirstElement(); + if (v < 0) { + return absl::nullopt; + } + return v; + } + case S8: + return GetFirstElement(); + case S16: + return GetFirstElement(); + case S32: + return GetFirstElement(); + case S64: + return GetFirstElement(); + default: + return absl::nullopt; + } +} + template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 7aee34437e6..1553d042e80 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -116,6 +117,9 @@ class LiteralBase { template NativeT GetFirstElement() const; + // As above but returns any integer type cast to an int64. + absl::optional GetFirstInteger() const; + // As Get(), but determines the correct type and converts the value // into text. string GetAsString(absl::Span multi_index, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc old mode 100644 new mode 100755 index 1fbb48669a3..55af8726dc8 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -816,6 +816,8 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Concatenate the indices and updates if (index_concat_is_safe && same_dimension_numbers && index_concat_dimension && + lhs_scatter_index->shape().element_type() == + rhs_scatter_index->shape().element_type() && ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, @@ -3636,6 +3638,39 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(), dynamic_slice->shape())); } + + // Convert a dynamic slice into a slice if all offsets are constant and the + // operand is not constant.
+ if (operand->opcode() != HloOpcode::kConstant && + absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, + dynamic_slice->operands().end()), + [](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConstant && + ShapeUtil::ElementIsIntegral(operand->shape()); + })) { + const int64 rank = operand->shape().rank(); + std::vector slice_starts(rank); + std::vector slice_limits(rank); + std::vector slice_strides(rank, 1); + + for (int64 i = 0; i < rank; ++i) { + absl::optional offset = + dynamic_slice->operand(i + 1)->literal().GetFirstInteger(); + if (!offset || *offset < 0) { + return Status::OK(); + } + const int64 max_offset = + dynamic_slice->operand(0)->shape().dimensions(i) - + dynamic_slice->shape().dimensions(i); + slice_starts[i] = std::min(max_offset, *offset); + slice_limits[i] = + std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i); + } + return ReplaceWithNewInstruction( + dynamic_slice, + HloInstruction::CreateSlice(dynamic_slice->shape(), operand, + slice_starts, slice_limits, slice_strides)); + } return Status::OK(); } @@ -3670,8 +3705,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( compatible = false; } } + PaddingConfig padding_config; if (compatible) { - PaddingConfig padding_config; for (int64 dim = 0; dim < updated_shape.rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); auto slice_dim_start = update_start_indx->operand(dim + offset); @@ -3680,37 +3715,32 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( break; } VLOG(2) << "slice :" << slice_dim_start->ToString(); - int64 beg; - if (slice_dim_start->shape().element_type() == S32) { - beg = slice_dim_start->literal().Get({}); - } else if (slice_dim_start->shape().element_type() == U32) { - beg = slice_dim_start->literal().Get({}); - } else { + absl::optional beg = + slice_dim_start->literal().GetFirstInteger(); + if (!beg) { compatible = false; break; } - VLOG(2) << "beg value:" << beg; 
+ VLOG(2) << "beg value:" << *beg; auto update_width = ShapeUtil::GetDimension(update_shape, dim); auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); - padding_config_dim->set_edge_padding_low(beg); + padding_config_dim->set_edge_padding_low(*beg); padding_config_dim->set_edge_padding_high( - std::max(bcast_width - (beg + update_width), int64{0})); + std::max(bcast_width - (*beg + update_width), int64{0})); // dynamic_update_slice does not specify a stride padding_config_dim->set_interior_padding(0); } - if (compatible) { - HloInstruction* pad = - computation_->AddInstruction(HloInstruction::CreatePad( - updated_shape, dus_update, pad_value, padding_config)); - VLOG(2) << dynamic_update_slice->ToString(); - VLOG(2) << " with pad:" << pad->ToString(); - VLOG(2) << " Computation before rewrite is: " - << dynamic_update_slice->parent()->ToString(); - auto res = ReplaceInstruction(dynamic_update_slice, pad); - VLOG(2) << " Computation after rewrite is: " - << pad->parent()->ToString(); - return res; - } + } + + if (compatible) { + HloInstruction* pad = + computation_->AddInstruction(HloInstruction::CreatePad( + updated_shape, dus_update, pad_value, padding_config)); + VLOG(2) << dynamic_update_slice->ToString(); + VLOG(2) << " with pad:" << pad->ToString(); + VLOG(2) << " Computation before rewrite is: " + << dynamic_update_slice->parent()->ToString(); + return ReplaceInstruction(dynamic_update_slice, pad); } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 5604146b6cc..6c8e80aa963 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4183,6 +4183,31 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) { + auto m = 
CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(2 << (i + 1))))); + } + Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200}); + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")), + params, + /*slice_sizes=*/{2, 20, 200})); + + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter()))); +} + // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update.