[XLA] Add a utility to transform dimensions from one shape's dimensions to another.
PiperOrigin-RevId: 303923477 Change-Id: Ie0127562ba07ad97058b7db95d66d729fa433c42
This commit is contained in:
parent
28aa08fc1e
commit
920fefa2ee
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/match.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -301,6 +303,40 @@ absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
|
||||
return bounds;
|
||||
}
|
||||
|
||||
ConvertedDimensionNumbers ConvertDimensionNumbers(
|
||||
absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
|
||||
absl::Span<const int64> to_sizes) {
|
||||
ConvertedDimensionNumbers dimensions;
|
||||
auto common_factors = CommonFactors(from_sizes, to_sizes);
|
||||
for (int64 i = 0; i < common_factors.size() - 1; ++i) {
|
||||
bool any_present = false;
|
||||
bool all_present = true;
|
||||
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
|
||||
++d) {
|
||||
const bool present = absl::c_linear_search(from_dimensions, d);
|
||||
any_present |= present;
|
||||
all_present &= present;
|
||||
}
|
||||
if (all_present) {
|
||||
for (int64 d = common_factors[i].second; d < common_factors[i + 1].second;
|
||||
++d) {
|
||||
dimensions.to_dimensions.push_back(d);
|
||||
}
|
||||
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
|
||||
++d) {
|
||||
dimensions.transformed_from_dimensions.push_back(d);
|
||||
}
|
||||
} else if (any_present) {
|
||||
for (int64 d = common_factors[i].first; d < common_factors[i + 1].first;
|
||||
++d) {
|
||||
if (absl::c_linear_search(from_dimensions, d)) {
|
||||
dimensions.untransformed_from_dimensions.push_back(d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return dimensions;
|
||||
}
|
||||
string SanitizeFileName(string file_name) {
|
||||
for (char& c : file_name) {
|
||||
if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
@ -506,6 +507,18 @@ int64 Product(absl::Span<const int64> xs);
|
||||
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
|
||||
absl::Span<const int64> a, absl::Span<const int64> b);
|
||||
|
||||
struct ConvertedDimensionNumbers {
|
||||
DimensionVector transformed_from_dimensions;
|
||||
DimensionVector untransformed_from_dimensions;
|
||||
DimensionVector to_dimensions;
|
||||
};
|
||||
|
||||
// Convert and unsorted list of dimensions from one shapes dimension sizes to
|
||||
// another shapes dimensions sizes.
|
||||
ConvertedDimensionNumbers ConvertDimensionNumbers(
|
||||
absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
|
||||
absl::Span<const int64> to_sizes);
|
||||
|
||||
// Removes illegal characters from filenames.
|
||||
string SanitizeFileName(string file_name);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user