[XLA] Compute source indices in each common factor of reshape from target indices in same common factor

Change: 144124189
Authored by A. Unique TensorFlower on 2017-01-10 13:54:29 -08:00; committed by TensorFlower Gardener
parent b76344d0ab
commit e2730973b1
6 changed files with 127 additions and 78 deletions
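The idea behind the change, illustrated with the shapes from the unit test added below (the arithmetic here is annotation, not part of the commit message): a reshape splits the input and output shapes into maximal consecutive groups of dimensions, the "common factors", that cover the same number of elements, and the source index within one group depends only on the target index within the same group. Reshaping {2, 5, 1, 3} to {1, 10, 3, 1} yields the factor bounds {(0, 0), (0, 1), (2, 2), (3, 2), (4, 3), (4, 4)}: input dimensions 0 and 1 (2 × 5 = 10 elements) pair with output dimension 1 (10 elements), and input dimension 3 pairs with output dimension 2 (3 elements each). The emitter previously linearized the whole target index and delinearized it over the whole input shape; it now does so factor by factor.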

tensorflow/compiler/xla/service/llvm_ir/ir_array.cc

@@ -69,7 +69,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
   CHECK_EQ(shape.dimensions_size(), multidim.size());
   CHECK(LayoutUtil::HasLayout(shape));
-  linear_ = Linearize(shape, ir_builder);
+  linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder);
 }

 IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
@@ -109,35 +109,41 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
     llvm::IRBuilder<>* builder) const {
   const auto& target_index = *this;
   CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
-  llvm::Value* logical_linear_index = Linearize(output_shape, builder);
-  // Delinearizes logical_linear_index for the source array in row-major
-  // collapsed order. The first rank-1 indices are the remainder of the
-  // linear index by each dimension size.
-  std::vector<std::pair<int64, int64>> unmodified_dims =
-      ShapeUtil::DimensionsUnmodifiedByReshape(input_shape, output_shape);
-  std::vector<llvm::Value*> source_multidim_index(ShapeUtil::Rank(input_shape));
-  for (int64 i = ShapeUtil::Rank(input_shape) - 1; i >= 0; --i) {
-    auto divisor = builder->getInt64(input_shape.dimensions(i));
-    if (input_shape.dimensions(i) <= 1) {
-      source_multidim_index[i] = builder->getInt64(0);
-    } else {
-      // Search unmodified_dims for a pair whose first element is exactly "i".
-      //
-      // Because unmodified_dims are sorted by both "first" and "second", and
-      // "i" is monotonically decreasing, we avoid redundant searching by
-      // popping the back of unmodified_dims until the rear pair's first element
-      // <= i. If we stop precisely at "i", we find a match.
-      while (!unmodified_dims.empty() && unmodified_dims.back().first > i) {
-        unmodified_dims.pop_back();
-      }
-      if (!unmodified_dims.empty() && unmodified_dims.back().first == i) {
-        source_multidim_index[i] = target_index[unmodified_dims.back().second];
+  std::vector<std::pair<int64, int64>> common_factors =
+      CommonFactors(AsInt64Slice(input_shape.dimensions()),
+                    AsInt64Slice(output_shape.dimensions()));
+  std::vector<llvm::Value*> source_multidim_index(
+      ShapeUtil::Rank(input_shape),
+      llvm::UndefValue::get(builder->getInt64Ty()));
+  // We compute the source indices in each common factor from only the target
+  // indices in the same common factor.
+  for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
+    llvm::Value* logical_linear_index =
+        Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
+                  multidim_, common_factors[k].second,
+                  common_factors[k + 1].second - common_factors[k].second))
+            .Linearize(
+                tensorflow::gtl::ArraySlice<int64>(
+                    AsInt64Slice(output_shape.dimensions()),
+                    common_factors[k].second,
+                    common_factors[k + 1].second - common_factors[k].second),
+                builder);
+    // Delinearizes logical_linear_index for the source array in row-major
+    // collapsed order. The first rank-1 indices are the remainder of the
+    // linear index by each dimension size.
+    for (int64 i = common_factors[k + 1].first - 1;
+         i >= common_factors[k].first; --i) {
+      llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i));
+      if (input_shape.dimensions(i) == 1) {
+        source_multidim_index[i] = builder->getInt64(0);
+      } else if (i == common_factors[k].first) {
+        source_multidim_index[i] = logical_linear_index;
       } else {
         source_multidim_index[i] =
             builder->CreateURem(logical_linear_index, divisor);
       }
+      logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
     }
-    logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
   }
   if (linear() != nullptr &&
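To make the per-factor arithmetic concrete, here is a minimal host-side sketch in plain C++ (our own construction, not code from this commit; it uses int64 arithmetic where SourceIndexOfReshape emits the equivalent LLVM IR, and the shapes come from the unit test below). For the reshape {2, 5, 1, 3} -> {1, 10, 3, 1}, the common factor between bounds (0, 1) and (2, 2) pairs input dimensions {2, 5} with output dimension {10}:

#include <cstdint>
#include <utility>

// Host-side analogue of the IR emitted for one common factor of the reshape
// {2, 5, 1, 3} -> {1, 10, 3, 1}: input dims [0, 2) vs. output dims [1, 2).
std::pair<int64_t, int64_t> SourceIndicesForFactor(int64_t target_j) {
  int64_t linear = target_j;  // Linearize over the factor's output dimensions
                              // (a single dimension here, so a no-op).
  int64_t src1 = linear % 5;  // Remainder by input dimension 1's size.
  linear /= 5;                // Shift to the next-more-major input dimension.
  int64_t src0 = linear;      // The factor's first dimension keeps the quotient.
  return {src0, src1};
}

For example, target index 7 along output dimension 1 maps to source index (1, 2), and indeed 1 * 5 + 2 == 7; no other component of the target index is consulted.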
@@ -160,8 +166,9 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose(
   return Index(operand_multidim_index);
 }

-llvm::Value* IrArray::Index::Linearize(const Shape& shape,
-                                       llvm::IRBuilder<>* builder) const {
+llvm::Value* IrArray::Index::Linearize(
+    tensorflow::gtl::ArraySlice<int64> dimensions,
+    llvm::IRBuilder<>* builder) const {
   // Each dimension is multiplied by the product of the sizes of all
   // earlier dimensions and added to the accumulator logical_linear_index.
   llvm::Value* logical_linear_index = builder->getInt64(0);
@@ -172,7 +179,7 @@ llvm::Value* IrArray::Index::Linearize(const Shape& shape,
                                          /*HasNUW=*/true, /*HasNSW=*/true);
     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                               /*HasNUW=*/true, /*HasNSW=*/true);
-    multiplier *= shape.dimensions(i);
+    multiplier *= dimensions[i];
   }
   return logical_linear_index;
 }
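The comment in Linearize describes plain row-major linearization. A constant-folded sketch of the same loop (ours; LinearizeRowMajor is a hypothetical name, and int64 arithmetic stands in for the CreateMul/CreateAdd instructions the method emits):

#include <cstdint>
#include <vector>

// Accumulate from the last (most-minor) dimension backwards: each index is
// scaled by the product of the sizes of all dimensions visited before it.
int64_t LinearizeRowMajor(const std::vector<int64_t>& index,
                          const std::vector<int64_t>& dims) {
  int64_t logical_linear_index = 0;
  int64_t multiplier = 1;
  for (int64_t i = static_cast<int64_t>(index.size()) - 1; i >= 0; --i) {
    logical_linear_index += index[i] * multiplier;
    multiplier *= dims[i];
  }
  return logical_linear_index;
}

For example, LinearizeRowMajor({1, 2}, {2, 5}) == 1 * 5 + 2 == 7.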

tensorflow/compiler/xla/service/llvm_ir/ir_array.h

@@ -124,7 +124,7 @@ class IrArray {
   // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
   // returns the index into the sole dimension 0 of the new shape.
-  llvm::Value* Linearize(const Shape& shape,
+  llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
                          llvm::IRBuilder<>* builder) const;

  private:

tensorflow/compiler/xla/shape_util.cc

@@ -767,57 +767,20 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
 /* static */ std::vector<std::pair<int64, int64>>
 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                          const Shape& output_shape) {
-  // Returns nil if the input/output shape has zero elements. This is safe but
-  // might be too conservative. Not a big deal for now because IR emitted for
-  // zero-element shapes are often trivially optimizable without the help of
-  // this method.
-  if (ShapeUtil::ElementsIn(input_shape) == 0 ||
-      ShapeUtil::ElementsIn(output_shape) == 0) {
-    return std::vector<std::pair<int64, int64>>();
-  }
-
-  std::vector<std::pair<int64, int64>> unmodified_dims;
-  int64 input_dim = 0;
-  int64 output_dim = 0;
-  // A reshape preserves input_dim as output_dim iff
-  // 1. input_dim and output_dim have the same size.
-  // 2. The size of the input subarray from dimension 0 to input_dim-1 equals
-  //    that of the output subarray from dimension 0 to output_dim-1.
-  VLOG(3) << "DimensionsUnmodifiedByReshape: input_shape="
-          << ShapeUtil::HumanString(input_shape)
-          << ", output_shape=" << ShapeUtil::HumanString(output_shape);
-  while (input_dim < ShapeUtil::Rank(input_shape) &&
-         output_dim < ShapeUtil::Rank(output_shape)) {
-    // partial_input_size is the product of sizes of input dimensions
-    // inclusively between the input_dim when this loop iteration starts and the
-    // current input_dim. partial_output_size is that of output dimensions. We
-    // compute these two values incrementally to save time.
-    int64 partial_input_size = input_shape.dimensions(input_dim);
-    int64 partial_output_size = output_shape.dimensions(output_dim);
-    // Move input_dim and output_dim forward until
-    // partial_input_size==partial_output_size.
-    while (partial_input_size != partial_output_size) {
-      if (partial_input_size < partial_output_size) {
-        ++input_dim;
-        partial_input_size *= input_shape.dimensions(input_dim);
-      } else {
-        ++output_dim;
-        partial_output_size *= output_shape.dimensions(output_dim);
-      }
+  // Unmodified dimensions are merely common factors of rank 1.
+  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
+                                      AsInt64Slice(output_shape.dimensions()));
+  for (size_t i = 0; i < common_factors.size() - 1;) {
+    if (1 != common_factors[i + 1].first - common_factors[i].first ||
+        1 != common_factors[i + 1].second - common_factors[i].second) {
+      common_factors.erase(common_factors.begin() + i);
+    } else {
+      ++i;
     }
-    CHECK_LT(input_dim, ShapeUtil::Rank(input_shape));
-    CHECK_LT(output_dim, ShapeUtil::Rank(output_shape));
-    if (input_shape.dimensions(input_dim) ==
-        output_shape.dimensions(output_dim)) {
-      unmodified_dims.push_back({input_dim, output_dim});
-      VLOG(3) << "Matching dimension pair: " << input_dim << ' ' << output_dim;
-    }
-    ++input_dim;
-    ++output_dim;
   }
-  return unmodified_dims;
+  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
+  common_factors.pop_back();
+  return common_factors;
 }

 /* static */ bool ShapeUtil::TransposeIsBitcast(
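Walking this through the shapes from the new unit test (our trace, not part of the commit): for input {2, 5, 1, 3} and output {1, 10, 3, 1}, CommonFactors returns {(0, 0), (0, 1), (2, 2), (3, 2), (4, 3), (4, 4)}. Only the segment from (3, 2) to (4, 3) spans exactly one dimension on each side, so every other bound is erased; popping the trailing (4, 4) then leaves {(3, 2)}, i.e. input dimension 3 (size 3) survives the reshape as output dimension 2 (size 3). A hypothetical check (ours; the MakeShape calls are illustrative):

std::vector<std::pair<int64, int64>> unmodified =
    ShapeUtil::DimensionsUnmodifiedByReshape(
        ShapeUtil::MakeShape(F32, {2, 5, 1, 3}),
        ShapeUtil::MakeShape(F32, {1, 10, 3, 1}));
CHECK(unmodified == (std::vector<std::pair<int64, int64>>{{3, 2}}));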

tensorflow/compiler/xla/util.cc

@@ -235,4 +235,50 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
   }
 }

+int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
+  return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies<int64>());
+}
+
+std::vector<std::pair<int64, int64>> CommonFactors(
+    tensorflow::gtl::ArraySlice<int64> a,
+    tensorflow::gtl::ArraySlice<int64> b) {
+  CHECK_EQ(Product(a), Product(b));
+  if (0 == Product(a)) {
+    return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
+  }
+
+  std::vector<std::pair<int64, int64>> bounds;
+  for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
+             partial_size_b = 1;
+       ;) {
+    if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
+      std::tie(prior_i, prior_j) = std::make_pair(i, j);
+      bounds.emplace_back(i, j);
+      continue;
+    }
+    bool in_bounds_i = i < a.size();
+    bool in_bounds_j = j < b.size();
+    if (!(in_bounds_i || in_bounds_j)) {
+      break;
+    }
+    bool next_a =
+        partial_size_a < partial_size_b ||
+        (in_bounds_i &&
+         (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
+    bool next_b =
+        partial_size_b < partial_size_a ||
+        (in_bounds_j &&
+         (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
+    if (next_a) {
+      partial_size_a *= a[i];
+      ++i;
+    }
+    if (next_b) {
+      partial_size_b *= b[j];
+      ++j;
+    }
+  }
+  return bounds;
+}
+
 }  // namespace xla
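A trace of the two-pointer loop on the third case of the new test (our arithmetic, not part of the commit): with a = {2, 5, 1, 3} and b = {1, 10, 3, 1}, partial_size_a/partial_size_b start at 1/1, so (0, 0) is recorded; consuming b[0] = 1 keeps 1/1 with j advanced, recording (0, 1); consuming a[0] = 2, then b[1] = 10, then a[1] = 5 reaches 10/10, recording (2, 2); consuming a[2] = 1 keeps 10/10, recording (3, 2); consuming a[3] = 3 and b[2] = 3 in the same step reaches 30/30, recording (4, 3); and consuming b[3] = 1 keeps 30/30, recording (4, 4) before both slices are exhausted. A bound is recorded only when the partial products match and at least one pointer has advanced since the previous record, which is what keeps each subsequence as short as possible.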

tensorflow/compiler/xla/util.h

@@ -238,6 +238,22 @@ std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
   return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
 }

+int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
+
+// Returns the start indices of consecutive non-overlapping subsequences of `a`
+// and `b` with the same product, i.e. `(i, j)` so
+// • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
+// • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
+// • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
+//         a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
+//         b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
+// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
+//
+// If the given shapes have non-zero size, returns the bounds of the shortest
+// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
+std::vector<std::pair<int64, int64>> CommonFactors(
+    tensorflow::gtl::ArraySlice<int64> a, tensorflow::gtl::ArraySlice<int64> b);
+
 }  // namespace xla

 #define XLA_LOG_LINES(SEV, STRING) LogLines(SEV, STRING, __FILE__, __LINE__)
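Sanity-checking this contract on the third case of the new test (our arithmetic): for a = {2, 5, 1, 3} and b = {1, 10, 3, 1}, CommonFactors returns {(0, 0), (0, 1), (2, 2), (3, 2), (4, 3), (4, 4)}, and each consecutive pair of bounds delimits subsequences of equal product: {} vs. {1} (both product 1), {2, 5} vs. {10} (product 10), {1} vs. {} (product 1), {3} vs. {3} (product 3), and {} vs. {1} (product 1).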

tensorflow/compiler/xla/util_test.cc

@@ -85,5 +85,22 @@ TEST(UtilTest, LogLines) {
   LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
 }

+TEST(UtilTest, CommonFactors) {
+  struct {
+    std::vector<int64> a, b;
+    std::vector<std::pair<int64, int64>> expected;
+  } test_cases[] = {
+      {/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
+      {/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},
+      {/*.a =*/{2, 5, 1, 3},
+       /*.b =*/{1, 10, 3, 1},
+       /*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
+  };
+  for (const auto& test_case : test_cases) {
+    EXPECT_TRUE(ContainersEqual(test_case.expected,
+                                CommonFactors(test_case.a, test_case.b)));
+  }
+}
+
 }  // namespace
 }  // namespace xla