diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1350f9e3e0b..a2993058321 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -449,7 +449,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cbbad741ce3..44e6a3c7bdb 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -199,34 +198,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { return literal; } -absl::optional LiteralBase::GetFirstInteger() const { - switch (shape().element_type()) { - case U8: - return GetFirstElement(); - case U16: - return GetFirstElement(); - case U32: - return GetFirstElement(); - case U64: { - int64 v = GetFirstElement(); - if (v < 0) { - return absl::nullopt; - } - return v; - } - case S8: - return GetFirstElement(); - case S16: - return GetFirstElement(); - case S32: - return GetFirstElement(); - case S64: - return GetFirstElement(); - default: - return absl::nullopt; - } -} - template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1553d042e80..7aee34437e6 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -27,7 +27,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -117,9 +116,6 @@ class LiteralBase { template NativeT GetFirstElement() const; - // As above but returns any integer type casted to an int64. - absl::optional GetFirstInteger() const; - // As Get(), but determines the correct type and converts the value // into text. string GetAsString(absl::Span multi_index, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc old mode 100755 new mode 100644 index 55af8726dc8..1fbb48669a3 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -816,8 +816,6 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Concatenate the indices and updates if (index_concat_is_safe && same_dimension_numbers && index_concat_dimension && - lhs_scatter_index->shape().element_type() == - rhs_scatter_index->shape().element_type() && ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, @@ -3638,39 +3636,6 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(), dynamic_slice->shape())); } - - // Convert a dynamic slice into a slice if all offsets are constant and the - // operand is not constant. If ev - if (operand->opcode() != HloOpcode::kConstant && - absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, - dynamic_slice->operands().end()), - [](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConstant && - ShapeUtil::ElementIsIntegral(operand->shape()); - })) { - const int64 rank = operand->shape().rank(); - std::vector slice_starts(rank); - std::vector slice_limits(rank); - std::vector slice_strides(rank, 1); - - for (int64 i = 0; i < rank; ++i) { - absl::optional offset = - dynamic_slice->operand(i + 1)->literal().GetFirstInteger(); - if (!offset || *offset < 0) { - return Status::OK(); - } - const int64 max_offset = - dynamic_slice->operand(0)->shape().dimensions(i) - - dynamic_slice->shape().dimensions(i); - slice_starts[i] = std::min(max_offset, *offset); - slice_limits[i] = - std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i); - } - return ReplaceWithNewInstruction( - dynamic_slice, - HloInstruction::CreateSlice(dynamic_slice->shape(), operand, - slice_starts, slice_limits, slice_strides)); - } return Status::OK(); } @@ -3705,8 +3670,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( compatible = false; } } - PaddingConfig padding_config; if (compatible) { + PaddingConfig padding_config; for (int64 dim = 0; dim < updated_shape.rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); auto slice_dim_start = update_start_indx->operand(dim + offset); @@ -3715,32 +3680,37 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( break; } VLOG(2) << "slice :" << slice_dim_start->ToString(); - absl::optional beg = - slice_dim_start->literal().GetFirstInteger(); - if (!beg) { + int64 beg; + if (slice_dim_start->shape().element_type() == S32) { + beg = slice_dim_start->literal().Get({}); + } else if (slice_dim_start->shape().element_type() == U32) { + beg = slice_dim_start->literal().Get({}); + } else { compatible = false; break; } - VLOG(2) << "beg value:" << *beg; + VLOG(2) << "beg value:" << beg; auto update_width = ShapeUtil::GetDimension(update_shape, dim); auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); - padding_config_dim->set_edge_padding_low(*beg); + padding_config_dim->set_edge_padding_low(beg); padding_config_dim->set_edge_padding_high( - std::max(bcast_width - (*beg + update_width), int64{0})); + std::max(bcast_width - (beg + update_width), int64{0})); // dynamic_update_slice does not specify a stride padding_config_dim->set_interior_padding(0); } - } - - if (compatible) { - HloInstruction* pad = - computation_->AddInstruction(HloInstruction::CreatePad( - updated_shape, dus_update, pad_value, padding_config)); - VLOG(2) << dynamic_update_slice->ToString(); - VLOG(2) << " with pad:" << pad->ToString(); - VLOG(2) << " Computation before rewrite is: " - << dynamic_update_slice->parent()->ToString(); - return ReplaceInstruction(dynamic_update_slice, pad); + if (compatible) { + HloInstruction* pad = + computation_->AddInstruction(HloInstruction::CreatePad( + updated_shape, dus_update, pad_value, padding_config)); + VLOG(2) << dynamic_update_slice->ToString(); + VLOG(2) << " with pad:" << pad->ToString(); + VLOG(2) << " Computation before rewrite is: " + << dynamic_update_slice->parent()->ToString(); + auto res = ReplaceInstruction(dynamic_update_slice, pad); + VLOG(2) << " Computation after rewrite is: " + << pad->parent()->ToString(); + return res; + } } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 6c8e80aa963..5604146b6cc 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4183,31 +4183,6 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } -TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) { - auto m = CreateNewVerifiedModule(); - HloComputation::Builder builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); - std::vector params; - for (int i = 0; i < 3; ++i) { - params.push_back(builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(2 << (i + 1))))); - } - Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200}); - builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, - builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "operand")), - params, - /*slice_sizes=*/{2, 20, 200})); - - auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(default_options_); - ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - GmockMatch(m::Slice(m::Parameter()))); -} - // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update.