[XLA] Add a utility to transform dimensions from one shape's dimensions to another.

PiperOrigin-RevId: 303923477
Change-Id: Ie0127562ba07ad97058b7db95d66d729fa433c42
This commit is contained in:
Blake Hechtman 2020-03-31 02:30:57 -07:00 committed by TensorFlower Gardener
parent 28aa08fc1e
commit 920fefa2ee
2 changed files with 49 additions and 0 deletions
tensorflow/compiler/xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <limits>
#include <numeric>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h"
@ -301,6 +303,40 @@ absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
return bounds;
}
ConvertedDimensionNumbers ConvertDimensionNumbers(
absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
absl::Span<const int64> to_sizes) {
ConvertedDimensionNumbers dimensions;
auto common_factors = CommonFactors(from_sizes, to_sizes);
for (int64 i = 0; i < common_factors.size() - 1; ++i) {
bool any_present = false;
bool all_present = true;
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
++d) {
const bool present = absl::c_linear_search(from_dimensions, d);
any_present |= present;
all_present &= present;
}
if (all_present) {
for (int64 d = common_factors[i].second; d < common_factors[i + 1].second;
++d) {
dimensions.to_dimensions.push_back(d);
}
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
++d) {
dimensions.transformed_from_dimensions.push_back(d);
}
} else if (any_present) {
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
++d) {
if (absl::c_linear_search(from_dimensions, d)) {
dimensions.untransformed_from_dimensions.push_back(d);
}
}
}
}
return dimensions;
}
string SanitizeFileName(string file_name) {
for (char& c : file_name) {
if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
@ -506,6 +507,18 @@ int64 Product(absl::Span<const int64> xs);
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
absl::Span<const int64> a, absl::Span<const int64> b);
struct ConvertedDimensionNumbers {
DimensionVector transformed_from_dimensions;
DimensionVector untransformed_from_dimensions;
DimensionVector to_dimensions;
};
// Convert and unsorted list of dimensions from one shapes dimension sizes to
// another shapes dimensions sizes.
ConvertedDimensionNumbers ConvertDimensionNumbers(
absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
absl::Span<const int64> to_sizes);
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);