[TF:XLA] Optimize the literal transpose operation
Optimize the literal transpose operation by avoiding item by item copies. Transposing a F32{128, 64, 64, 32} with a {0, 3, 2, 1} permutation, on a Xeon E5-1650 v3, took ~40s before, and ~130ms after. Made literal Reshape support not MonotonicDim0Major layouts. Optimized the literal Relayout operation to use the new Copy() operation, and to hence cover all the primitive types. Added unittest for the LiteralUtil::Populate() API. Change: 155265178
This commit is contained in:
parent
87ba9f5370
commit
b04d0985f3
tensorflow/compiler/xla
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
@ -308,37 +309,16 @@ template <typename T, typename WT>
|
||||
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::Relayout(
|
||||
const Literal& original, const Layout& layout) {
|
||||
// Note: if this were a performance bottleneck, we avoid cloning and just make
|
||||
// an uninitialized array instead, since all values are clobbered below.
|
||||
std::unique_ptr<Literal> result = CloneToUnique(original);
|
||||
*result->mutable_shape()->mutable_layout() = layout;
|
||||
const PrimitiveType primitive_type = original.shape().element_type();
|
||||
switch (primitive_type) {
|
||||
case F32:
|
||||
LiteralUtil::EachCell<float>(
|
||||
original,
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
|
||||
LiteralUtil::Set<float>(result.get(), indices, value);
|
||||
});
|
||||
return result;
|
||||
case S32:
|
||||
LiteralUtil::EachCell<int32>(
|
||||
original,
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, int32 value) {
|
||||
LiteralUtil::Set<int32>(result.get(), indices, value);
|
||||
});
|
||||
return result;
|
||||
case U32:
|
||||
LiteralUtil::EachCell<uint32>(
|
||||
original,
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, uint32 value) {
|
||||
LiteralUtil::Set<uint32>(result.get(), indices, value);
|
||||
});
|
||||
return result;
|
||||
default:
|
||||
LOG(FATAL) << "not yet implemented: "
|
||||
<< PrimitiveType_Name(primitive_type);
|
||||
}
|
||||
|
||||
const Shape& shape = original.shape();
|
||||
std::vector<int64> base(ShapeUtil::Rank(shape), 0);
|
||||
std::vector<int64> copy_size(shape.dimensions().begin(),
|
||||
shape.dimensions().end());
|
||||
|
||||
TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size));
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<Literal>> LiteralUtil::Reshape(
|
||||
@ -346,25 +326,19 @@ template <typename T, typename WT>
|
||||
if (ShapeUtil::IsTuple(input.shape())) {
|
||||
return InvalidArgument("Reshape does not support tuples.");
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> output;
|
||||
if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) {
|
||||
return Unimplemented(
|
||||
"Input shape must have a monotonic layout where dimension 0 is major, "
|
||||
"was: %s",
|
||||
LayoutUtil::HumanString(input.shape().layout()).c_str());
|
||||
std::vector<int64> minor_to_major(ShapeUtil::Rank(input.shape()));
|
||||
std::iota(minor_to_major.rbegin(), minor_to_major.rend(),
|
||||
static_cast<int64>(0));
|
||||
output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major));
|
||||
} else {
|
||||
output = CloneToUnique(input);
|
||||
}
|
||||
std::vector<int64> layout(dimensions.size());
|
||||
std::iota(layout.rbegin(), layout.rend(), 0);
|
||||
|
||||
// Because the layout is monotonic, we can simply reuse the same sequence of
|
||||
// values without changing their order.
|
||||
std::unique_ptr<Literal> output = CloneToUnique(input);
|
||||
output->clear_shape();
|
||||
output->mutable_shape()->set_element_type(input.shape().element_type());
|
||||
for (int64 dimension : dimensions) {
|
||||
output->mutable_shape()->add_dimensions(dimension);
|
||||
}
|
||||
*output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout);
|
||||
*output->mutable_shape() =
|
||||
ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
|
||||
|
||||
int64 elements_before = ShapeUtil::ElementsIn(input.shape());
|
||||
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
|
||||
@ -378,73 +352,42 @@ template <typename T, typename WT>
|
||||
return std::move(output);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <class T>
|
||||
void TransposeLiteralInternal(const Literal& original,
|
||||
tensorflow::gtl::ArraySlice<int64> permutation,
|
||||
Literal* result) {
|
||||
std::vector<int64> new_indices(ShapeUtil::Rank(original.shape()));
|
||||
LiteralUtil::EachCell<T>(
|
||||
original, [&](tensorflow::gtl::ArraySlice<int64> indices, T value) {
|
||||
for (int64 i = 0; i < indices.size(); ++i) {
|
||||
new_indices[i] = indices[permutation[i]];
|
||||
}
|
||||
LiteralUtil::Set<T>(result, new_indices, value);
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::Transpose(
|
||||
const Literal& original, tensorflow::gtl::ArraySlice<int64> permutation) {
|
||||
CHECK(!ShapeUtil::IsTuple(original.shape()))
|
||||
<< "tuple is not supported for transpose";
|
||||
std::vector<int64> dimension_numbers(ShapeUtil::Rank(original.shape()));
|
||||
std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0);
|
||||
CHECK(std::is_permutation(permutation.begin(), permutation.end(),
|
||||
dimension_numbers.begin()))
|
||||
<< "given permutation is not a permutation of dimension numbers";
|
||||
std::vector<int64> new_dimension_sizes;
|
||||
for (const int64 dim : permutation) {
|
||||
new_dimension_sizes.push_back(original.shape().dimensions(dim));
|
||||
}
|
||||
const auto result_shape = ShapeUtil::MakeShape(
|
||||
original.shape().element_type(), new_dimension_sizes);
|
||||
std::unique_ptr<Literal> result = CloneToUnique(original);
|
||||
*result->mutable_shape() = result_shape;
|
||||
const PrimitiveType primitive_type = original.shape().element_type();
|
||||
switch (primitive_type) {
|
||||
case F32:
|
||||
TransposeLiteralInternal<float>(original, permutation, result.get());
|
||||
return result;
|
||||
case F64:
|
||||
TransposeLiteralInternal<double>(original, permutation, result.get());
|
||||
return result;
|
||||
case PRED:
|
||||
TransposeLiteralInternal<bool>(original, permutation, result.get());
|
||||
return result;
|
||||
case S8:
|
||||
TransposeLiteralInternal<int8>(original, permutation, result.get());
|
||||
return result;
|
||||
case U8:
|
||||
TransposeLiteralInternal<uint8>(original, permutation, result.get());
|
||||
return result;
|
||||
case S32:
|
||||
TransposeLiteralInternal<int32>(original, permutation, result.get());
|
||||
return result;
|
||||
case U32:
|
||||
TransposeLiteralInternal<uint32>(original, permutation, result.get());
|
||||
return result;
|
||||
case S64:
|
||||
TransposeLiteralInternal<int64>(original, permutation, result.get());
|
||||
return result;
|
||||
case U64:
|
||||
TransposeLiteralInternal<uint64>(original, permutation, result.get());
|
||||
return result;
|
||||
default:
|
||||
LOG(FATAL) << "not yet implemented: "
|
||||
<< PrimitiveType_Name(primitive_type);
|
||||
<< "Tuple is not supported for transpose";
|
||||
CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape())))
|
||||
<< "Given permutation is not a permutation of dimension numbers";
|
||||
// To transpose the array, we just permute the dimensions and layout, and
|
||||
// do a straight memory copy of the raw data set.
|
||||
// This is considerably faster than iterating over every array element using
|
||||
// the EachCell<>() and Set<>() APIs.
|
||||
std::vector<int64> inverse_permutation = InversePermutation(permutation);
|
||||
Shape shape =
|
||||
ShapeUtil::PermuteDimensions(inverse_permutation, original.shape());
|
||||
// Replace the layout with one affine to the original shape, such that a
|
||||
// transpose operation can be performed by leaving the flat values
|
||||
// representation intact.
|
||||
// For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
|
||||
// The shape with affine layout resulting from that operation will be
|
||||
// F32[8,11]{0,1}, since it leave the original most minor (the 8 sized), the
|
||||
// most minor.
|
||||
// Essentially, given MinMaj(Di) the position of the Di dimension within the
|
||||
// minor to major vector, and given T(Di) the index that the original Di
|
||||
// dimension has within the transposed array, a layout is affine if
|
||||
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
|
||||
// vector of the affine layout.
|
||||
Layout* layout = shape.mutable_layout();
|
||||
layout->clear_minor_to_major();
|
||||
for (auto index : original.shape().layout().minor_to_major()) {
|
||||
layout->add_minor_to_major(inverse_permutation[index]);
|
||||
}
|
||||
std::unique_ptr<Literal> new_literal = CreateFromShape(shape);
|
||||
DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
|
||||
ShapeUtil::ByteSizeOf(original.shape()));
|
||||
std::memcpy(MutableInternalData(new_literal.get()), InternalData(original),
|
||||
ShapeUtil::ByteSizeOf(original.shape()));
|
||||
return new_literal;
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::Slice(
|
||||
|
@ -239,6 +239,11 @@ class LiteralUtil {
|
||||
// Clones literal into an owned unique_ptr version.
|
||||
static std::unique_ptr<Literal> CloneToUnique(const Literal& literal);
|
||||
|
||||
// Returns the linear index of the given index within the literal's
|
||||
// element_type repeated field.
|
||||
static int64 LinearIndex(const Literal& literal,
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index);
|
||||
|
||||
// Gets or sets an element in the literal at the given index. The index is
|
||||
// CHECKed against the dimension sizes.
|
||||
template <typename NativeT>
|
||||
@ -427,11 +432,6 @@ class LiteralUtil {
|
||||
"Cannot map native type to primitive type.");
|
||||
}
|
||||
|
||||
// Returns the linear index of the given index within the literal's
|
||||
// element_type repeated field.
|
||||
static int64 LinearIndex(const Literal& literal,
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index);
|
||||
|
||||
// Internal template helper for the Copy() API, matching its arguments one by
|
||||
// one.
|
||||
//
|
||||
|
@ -469,6 +469,26 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
|
||||
EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
|
||||
// clang-format off
|
||||
// F32[1x3x2x4]
|
||||
auto original = LiteralUtil::CreateR4WithLayout<float>({{
|
||||
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
||||
{{18, 19, 20, 21}, {22, 23, 24, 25}},
|
||||
{{26, 27, 28, 29}, {30, 31, 32, 33}},
|
||||
}}, layout_r4_dim0minor_);
|
||||
// F32[1x3x4x2]
|
||||
auto expected = LiteralUtil::CreateR3WithLayout<float>({
|
||||
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
|
||||
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
|
||||
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
|
||||
}, layout_r3_dim0major_);
|
||||
// clang-format on
|
||||
auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie();
|
||||
|
||||
EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, TransposeR0) {
|
||||
auto original = LiteralUtil::CreateR0<float>(1.7f);
|
||||
auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{});
|
||||
@ -659,15 +679,15 @@ TEST_F(LiteralUtilTest, Copy) {
|
||||
primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
|
||||
auto blank = LiteralUtil::CreateFromShape(shape);
|
||||
auto source = LiteralUtil::CreateFromShape(shape);
|
||||
const int64 sbase[] = {0, 0, 0, 0};
|
||||
const int64 incr[] = {1, 1, 1, 1};
|
||||
const int64 zero_base[] = {0, 0, 0, 0};
|
||||
const int64 step[] = {1, 1, 1, 1};
|
||||
uint32 seqnr = 0;
|
||||
auto init_proc = [&](const std::vector<int64>& indexes) {
|
||||
LiteralUtil::Set(source.get(), indexes, ++seqnr);
|
||||
return true;
|
||||
};
|
||||
|
||||
ShapeUtil::ForEachIndex(source->shape(), sbase, dimensions, incr,
|
||||
ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
|
||||
init_proc);
|
||||
|
||||
const int64 src_base[] = {3, 1, 5, 7};
|
||||
@ -691,7 +711,7 @@ TEST_F(LiteralUtilTest, Copy) {
|
||||
bval == LiteralUtil::Get<uint32>(*source, source_indexes));
|
||||
return matched;
|
||||
};
|
||||
ShapeUtil::ForEachIndex(source->shape(), sbase, copy_size, incr,
|
||||
ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
|
||||
check_proc);
|
||||
EXPECT_TRUE(matched);
|
||||
}
|
||||
@ -710,5 +730,43 @@ TEST_F(LiteralUtilTest, CopyScalars) {
|
||||
EXPECT_EQ(LiteralUtil::Get<uint32>(*vect, {4}), 17);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, Populate) {
|
||||
struct PopulateData {
|
||||
std::vector<int64> dimensions;
|
||||
std::vector<int64> layout;
|
||||
} populate_data[] = {
|
||||
{{}, {}},
|
||||
{{16}, {0}},
|
||||
{{4, 16}, {1, 0}},
|
||||
{{21, 12}, {0, 1}},
|
||||
{{6, 11, 17}, {2, 0, 1}},
|
||||
{{6, 11, 5, 17}, {3, 2, 0, 1}},
|
||||
};
|
||||
for (const auto& data : populate_data) {
|
||||
Shape shape = ShapeUtil::MakeShapeWithLayout(
|
||||
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
|
||||
data.layout);
|
||||
auto literal = LiteralUtil::CreateFromShape(shape);
|
||||
auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
|
||||
// Offsets from linear index just to avoid R0 literals to be initialized
|
||||
// with zero.
|
||||
return LiteralUtil::LinearIndex(*literal, indexes) + 17;
|
||||
};
|
||||
TF_EXPECT_OK(LiteralUtil::Populate<uint32>(literal.get(), generator));
|
||||
|
||||
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
||||
std::vector<int64> step(data.dimensions.size(), 1);
|
||||
bool matched = true;
|
||||
auto check_function = [&](const std::vector<int64>& indexes) {
|
||||
auto value = LiteralUtil::Get<uint32>(*literal, indexes);
|
||||
matched = matched && (value == generator(indexes));
|
||||
return matched;
|
||||
};
|
||||
ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
|
||||
check_function);
|
||||
EXPECT_TRUE(matched);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -195,7 +195,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
|
||||
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));
|
||||
|
||||
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
|
||||
bool matched = true;
|
||||
|
@ -728,9 +728,17 @@ Status ForEachMutableSubshapeHelper(
|
||||
new_shape.add_dimensions(dim);
|
||||
}
|
||||
if (shape.has_layout()) {
|
||||
new_shape.mutable_layout()->clear_minor_to_major();
|
||||
Layout* new_layout = new_shape.mutable_layout();
|
||||
new_layout->clear_minor_to_major();
|
||||
for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
|
||||
new_shape.mutable_layout()->add_minor_to_major(index);
|
||||
new_layout->add_minor_to_major(index);
|
||||
}
|
||||
if (shape.layout().padded_dimensions_size() > 0) {
|
||||
new_layout->clear_padded_dimensions();
|
||||
for (auto dim :
|
||||
Permute(permutation, shape.layout().padded_dimensions())) {
|
||||
new_layout->add_padded_dimensions(dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
return new_shape;
|
||||
@ -1057,7 +1065,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
DCHECK_EQ(count.size(), base.size());
|
||||
const Layout& layout = shape.layout();
|
||||
int64 rank = layout.minor_to_major_size();
|
||||
int64 n = 0;
|
||||
// Allows handling R0 arrays, such that the visitor function will be called
|
||||
// once with the proper empty indexes.
|
||||
int64 n = -1;
|
||||
std::vector<int64> indexes(base.begin(), base.end());
|
||||
while (n < rank && visitor_function(indexes)) {
|
||||
// Increments dimensions in minor to major order.
|
||||
|
@ -153,16 +153,26 @@ string Reindent(tensorflow::StringPiece original,
|
||||
});
|
||||
}
|
||||
|
||||
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
|
||||
if (rank != permutation.size()) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int64> output(permutation.size(), -1);
|
||||
for (auto index : permutation) {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, rank);
|
||||
output[index] = 0;
|
||||
}
|
||||
return std::find(output.begin(), output.end(), -1) == output.end();
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
tensorflow::gtl::ArraySlice<int64> input_permutation) {
|
||||
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
|
||||
std::vector<int64> output_permutation(input_permutation.size(), -1);
|
||||
for (size_t i = 0; i < input_permutation.size(); ++i) {
|
||||
output_permutation[input_permutation[i]] = i;
|
||||
}
|
||||
DCHECK_EQ(
|
||||
0, std::count(output_permutation.begin(), output_permutation.end(), -1));
|
||||
DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(),
|
||||
output_permutation.begin()));
|
||||
return output_permutation;
|
||||
}
|
||||
|
||||
|
@ -177,6 +177,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
|
||||
string Reindent(tensorflow::StringPiece original,
|
||||
tensorflow::StringPiece indentation);
|
||||
|
||||
// Checks whether permutation is a permutation of the [0, rank) integer range.
|
||||
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
|
||||
|
||||
// Applies `permutation` on `input` and returns the permuted array.
|
||||
// For each i, output[permutation[i]] = input[i].
|
||||
//
|
||||
@ -187,12 +190,11 @@ template <template <typename...> class C, typename T>
|
||||
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||
C<T> input_) {
|
||||
tensorflow::gtl::ArraySlice<T> input(input_);
|
||||
CHECK_EQ(permutation.size(), input.size());
|
||||
CHECK(IsPermutation(permutation, input.size()));
|
||||
std::vector<T> output(input.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
output[permutation[i]] = input[i];
|
||||
}
|
||||
DCHECK(std::is_permutation(input.begin(), input.end(), output.begin()));
|
||||
return output;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user