From fb695b89e018f36855108f33177252f329bfc8ac Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sat, 13 Apr 2019 13:10:36 -0700 Subject: [PATCH] [XLA] Migrate a function from AlgebraicSimplifier to ShapeUtil PiperOrigin-RevId: 243440029 --- .../xla/service/algebraic_simplifier.cc | 35 ++----------------- tensorflow/compiler/xla/shape_util.cc | 26 ++++++++++++++ tensorflow/compiler/xla/shape_util.h | 14 ++++++++ 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 9eb6e9310dd..dc20123ee5f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1992,40 +1992,11 @@ Status AlgebraicSimplifierVisitor::HandleGetTupleElement( namespace { -// Return whether the given reshape instruction leaves the dimensions at the -// given input indices unmodified, and returns their output indices. -// -// Example: -// input_dim_indices = {2, 3} -// input shape = T[a, b, x, y, cd] -// output shape = T[ab, x, 1, y, c, d] -// return value = {1, 3} -// -// Precondition: input_dim_indices is sorted. absl::optional> ReshapeLeavesDimensionsUnmodified( const HloInstruction* hlo, absl::Span input_dim_indices) { - CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); - CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); - - std::vector output_dim_indices; - std::vector> unmodified_dims = - ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(), - hlo->shape()); - size_t i = 0; // index to unmodified_dims - for (int64 input_dim_index : input_dim_indices) { - // Search unmodified_dims for input_dim_index. We can search from the last - // matching position because input_dim_indices is guaranteed to be sorted. - while (i < unmodified_dims.size() && - unmodified_dims[i].first < input_dim_index) { - ++i; - } - if (i >= unmodified_dims.size() || - unmodified_dims[i].first != input_dim_index) { - return absl::nullopt; - } - output_dim_indices.push_back(unmodified_dims[i].second); - } - return output_dim_indices; + CHECK_EQ(hlo->opcode(), HloOpcode::kReshape); + return ShapeUtil::ReshapeLeavesDimensionsUnmodified( + hlo->operand(0)->shape(), hlo->shape(), input_dim_indices); } // Returns true if the output of "instruction" is a permutation of the diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index de3b58ff46c..340f6793ccd 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1078,6 +1078,32 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return common_factors; } +/* static */ absl::optional> +ShapeUtil::ReshapeLeavesDimensionsUnmodified( + const Shape& from_shape, const Shape& to_shape, + absl::Span input_dim_indices) { + CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); + + std::vector output_dim_indices; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape); + size_t i = 0; // index to unmodified_dims + for (int64 input_dim_index : input_dim_indices) { + // Search unmodified_dims for input_dim_index. We can search from the last + // matching position because input_dim_indices is guaranteed to be sorted. + while (i < unmodified_dims.size() && + unmodified_dims[i].first < input_dim_index) { + ++i; + } + if (i >= unmodified_dims.size() || + unmodified_dims[i].first != input_dim_index) { + return absl::nullopt; + } + output_dim_indices.push_back(unmodified_dims[i].second); + } + return output_dim_indices; +} + /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, absl::Span dimension_mapping) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index ee24e39f052..0065a3b8784 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -578,6 +578,20 @@ class ShapeUtil { static std::vector> DimensionsUnmodifiedByReshape( const Shape& input_shape, const Shape& output_shape); + // Return whether the given reshape instruction leaves the dimensions at the + // given input indices unmodified, and returns their output indices. + // + // Example: + // input_dim_indices = {2, 3} + // input shape = T[a, b, x, y, cd] + // output shape = T[ab, x, 1, y, c, d] + // return value = {1, 3} + // + // Precondition: input_dim_indices is sorted. + static absl::optional> ReshapeLeavesDimensionsUnmodified( + const Shape& from_shape, const Shape& to_shape, + absl::Span input_dim_indices); + // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast.