From 920fefa2eeffddd491845d798836a6800668ded9 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Tue, 31 Mar 2020 02:30:57 -0700 Subject: [PATCH] [XLA] Add a utility to transform dimensions from one shape's dimensions to another. PiperOrigin-RevId: 303923477 Change-Id: Ie0127562ba07ad97058b7db95d66d729fa433c42 --- tensorflow/compiler/xla/util.cc | 36 +++++++++++++++++++++++++++++++++ tensorflow/compiler/xla/util.h | 13 ++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 6711779cd2b..1fbce96625b 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/strings/match.h" @@ -28,6 +29,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" @@ -301,6 +303,40 @@ absl::InlinedVector, 8> CommonFactors( return bounds; } +ConvertedDimensionNumbers ConvertDimensionNumbers( + absl::Span from_dimensions, absl::Span from_sizes, + absl::Span to_sizes) { + ConvertedDimensionNumbers dimensions; + auto common_factors = CommonFactors(from_sizes, to_sizes); + for (int64 i = 0; i < common_factors.size() - 1; ++i) { + bool any_present = false; + bool all_present = true; + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + const bool present = absl::c_linear_search(from_dimensions, d); + any_present |= present; + all_present &= present; + } + if (all_present) { + for (int64 d = common_factors[i].second; d < common_factors[i + 1].second; + ++d) { + dimensions.to_dimensions.push_back(d); + } + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + dimensions.transformed_from_dimensions.push_back(d); + } + } else if (any_present) { + for (int64 d = common_factors[i].first; d < common_factors[i + 1].first; + ++d) { + if (absl::c_linear_search(from_dimensions, d)) { + dimensions.untransformed_from_dimensions.push_back(d); + } + } + } + } + return dimensions; +} string SanitizeFileName(string file_name) { for (char& c : file_name) { if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 3ef41249d24..44a5bf4ea33 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" @@ -506,6 +507,18 @@ int64 Product(absl::Span xs); absl::InlinedVector, 8> CommonFactors( absl::Span a, absl::Span b); +struct ConvertedDimensionNumbers { + DimensionVector transformed_from_dimensions; + DimensionVector untransformed_from_dimensions; + DimensionVector to_dimensions; +}; + +// Convert and unsorted list of dimensions from one shapes dimension sizes to +// another shapes dimensions sizes. +ConvertedDimensionNumbers ConvertDimensionNumbers( + absl::Span from_dimensions, absl::Span from_sizes, + absl::Span to_sizes); + // Removes illegal characters from filenames. string SanitizeFileName(string file_name);