[XLA] Compute source indices in each common factor of reshape from target indices in same common factor
Change: 144124189
This commit is contained in:
parent
b76344d0ab
commit
e2730973b1
@ -69,7 +69,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
|
||||
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
|
||||
CHECK_EQ(shape.dimensions_size(), multidim.size());
|
||||
CHECK(LayoutUtil::HasLayout(shape));
|
||||
linear_ = Linearize(shape, ir_builder);
|
||||
linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder);
|
||||
}
|
||||
|
||||
IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
|
||||
@ -109,35 +109,41 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
|
||||
llvm::IRBuilder<>* builder) const {
|
||||
const auto& target_index = *this;
|
||||
CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
|
||||
llvm::Value* logical_linear_index = Linearize(output_shape, builder);
|
||||
// Delinearizes logical_linear_index for the source array in row-major
|
||||
// collapsed order. The first rank-1 indices are the remainder of the
|
||||
// linear index by each dimension size.
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(input_shape, output_shape);
|
||||
std::vector<llvm::Value*> source_multidim_index(ShapeUtil::Rank(input_shape));
|
||||
for (int64 i = ShapeUtil::Rank(input_shape) - 1; i >= 0; --i) {
|
||||
auto divisor = builder->getInt64(input_shape.dimensions(i));
|
||||
if (input_shape.dimensions(i) <= 1) {
|
||||
source_multidim_index[i] = builder->getInt64(0);
|
||||
} else {
|
||||
// Search unmodified_dims for a pair whose first element is exactly "i".
|
||||
//
|
||||
// Because unmodified_dims are sorted by both "first" and "second", and
|
||||
// "i" is monotonically decreasing, we avoid redundant searching by
|
||||
// popping the back of unmodified_dims until the rear pair's first element
|
||||
// <= i. If we stop precisely at "i", we find a match.
|
||||
while (!unmodified_dims.empty() && unmodified_dims.back().first > i) {
|
||||
unmodified_dims.pop_back();
|
||||
}
|
||||
if (!unmodified_dims.empty() && unmodified_dims.back().first == i) {
|
||||
source_multidim_index[i] = target_index[unmodified_dims.back().second];
|
||||
std::vector<std::pair<int64, int64>> common_factors =
|
||||
CommonFactors(AsInt64Slice(input_shape.dimensions()),
|
||||
AsInt64Slice(output_shape.dimensions()));
|
||||
std::vector<llvm::Value*> source_multidim_index(
|
||||
ShapeUtil::Rank(input_shape),
|
||||
llvm::UndefValue::get(builder->getInt64Ty()));
|
||||
// We compute the source indices in each common factor from only the target
|
||||
// indices in the same common factor.
|
||||
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
|
||||
llvm::Value* logical_linear_index =
|
||||
Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
|
||||
multidim_, common_factors[k].second,
|
||||
common_factors[k + 1].second - common_factors[k].second))
|
||||
.Linearize(
|
||||
tensorflow::gtl::ArraySlice<int64>(
|
||||
AsInt64Slice(output_shape.dimensions()),
|
||||
common_factors[k].second,
|
||||
common_factors[k + 1].second - common_factors[k].second),
|
||||
builder);
|
||||
// Delinearizes logical_linear_index for the source array in row-major
|
||||
// collapsed order. The first rank-1 indices are the remainder of the
|
||||
// linear index by each dimension size.
|
||||
for (int64 i = common_factors[k + 1].first - 1;
|
||||
i >= common_factors[k].first; --i) {
|
||||
llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i));
|
||||
if (input_shape.dimensions(i) == 1) {
|
||||
source_multidim_index[i] = builder->getInt64(0);
|
||||
} else if (i == common_factors[k].first) {
|
||||
source_multidim_index[i] = logical_linear_index;
|
||||
} else {
|
||||
source_multidim_index[i] =
|
||||
builder->CreateURem(logical_linear_index, divisor);
|
||||
}
|
||||
logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
|
||||
}
|
||||
logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
|
||||
}
|
||||
|
||||
if (linear() != nullptr &&
|
||||
@ -160,8 +166,9 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose(
|
||||
return Index(operand_multidim_index);
|
||||
}
|
||||
|
||||
llvm::Value* IrArray::Index::Linearize(const Shape& shape,
|
||||
llvm::IRBuilder<>* builder) const {
|
||||
llvm::Value* IrArray::Index::Linearize(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
llvm::IRBuilder<>* builder) const {
|
||||
// Each dimension is multiplied by the product of the sizes of all
|
||||
// earlier dimensions and added to the accumulator logical_linear_index.
|
||||
llvm::Value* logical_linear_index = builder->getInt64(0);
|
||||
@ -172,7 +179,7 @@ llvm::Value* IrArray::Index::Linearize(const Shape& shape,
|
||||
/*HasNUW=*/true, /*HasNSW=*/true);
|
||||
logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
|
||||
/*HasNUW=*/true, /*HasNSW=*/true);
|
||||
multiplier *= shape.dimensions(i);
|
||||
multiplier *= dimensions[i];
|
||||
}
|
||||
return logical_linear_index;
|
||||
}
|
||||
|
@ -124,7 +124,7 @@ class IrArray {
|
||||
|
||||
// Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
|
||||
// returns the index into the sole dimension 0 of the new shape.
|
||||
llvm::Value* Linearize(const Shape& shape,
|
||||
llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
llvm::IRBuilder<>* builder) const;
|
||||
|
||||
private:
|
||||
|
@ -767,57 +767,20 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
|
||||
/* static */ std::vector<std::pair<int64, int64>>
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
const Shape& output_shape) {
|
||||
// Returns nil if the input/output shape has zero elements. This is safe but
|
||||
// might be too conservative. Not a big deal for now because IR emitted for
|
||||
// zero-element shapes are often trivially optimizable without the help of
|
||||
// this method.
|
||||
if (ShapeUtil::ElementsIn(input_shape) == 0 ||
|
||||
ShapeUtil::ElementsIn(output_shape) == 0) {
|
||||
return std::vector<std::pair<int64, int64>>();
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims;
|
||||
int64 input_dim = 0;
|
||||
int64 output_dim = 0;
|
||||
|
||||
// A reshape preserves input_dim as output_dim iff
|
||||
// 1. input_dim and output_dim have the same size.
|
||||
// 2. The size of the input subarray from dimension 0 to input_dim-1 equals
|
||||
// that of the output subarray from dimension 0 to output_dim-1.
|
||||
VLOG(3) << "DimensionsUnmodifiedByReshape: input_shape="
|
||||
<< ShapeUtil::HumanString(input_shape)
|
||||
<< ", output_shape=" << ShapeUtil::HumanString(output_shape);
|
||||
while (input_dim < ShapeUtil::Rank(input_shape) &&
|
||||
output_dim < ShapeUtil::Rank(output_shape)) {
|
||||
// partial_input_size is the product of sizes of input dimensions
|
||||
// inclusively between the input_dim when this loop iteration starts and the
|
||||
// current input_dim. partial_output_size is that of output dimensions. We
|
||||
// compute these two values incrementally to save time.
|
||||
int64 partial_input_size = input_shape.dimensions(input_dim);
|
||||
int64 partial_output_size = output_shape.dimensions(output_dim);
|
||||
// Move input_dim and output_dim forward until
|
||||
// partial_input_size==partial_output_size.
|
||||
while (partial_input_size != partial_output_size) {
|
||||
if (partial_input_size < partial_output_size) {
|
||||
++input_dim;
|
||||
partial_input_size *= input_shape.dimensions(input_dim);
|
||||
} else {
|
||||
++output_dim;
|
||||
partial_output_size *= output_shape.dimensions(output_dim);
|
||||
}
|
||||
// Unmodified dimensions are merely common factors of rank 1.
|
||||
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
|
||||
AsInt64Slice(output_shape.dimensions()));
|
||||
for (size_t i = 0; i < common_factors.size() - 1;) {
|
||||
if (1 != common_factors[i + 1].first - common_factors[i].first ||
|
||||
1 != common_factors[i + 1].second - common_factors[i].second) {
|
||||
common_factors.erase(common_factors.begin() + i);
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
CHECK_LT(input_dim, ShapeUtil::Rank(input_shape));
|
||||
CHECK_LT(output_dim, ShapeUtil::Rank(output_shape));
|
||||
if (input_shape.dimensions(input_dim) ==
|
||||
output_shape.dimensions(output_dim)) {
|
||||
unmodified_dims.push_back({input_dim, output_dim});
|
||||
VLOG(3) << "Matching dimension pair: " << input_dim << ' ' << output_dim;
|
||||
}
|
||||
++input_dim;
|
||||
++output_dim;
|
||||
}
|
||||
|
||||
return unmodified_dims;
|
||||
// `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
|
||||
common_factors.pop_back();
|
||||
return common_factors;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::TransposeIsBitcast(
|
||||
|
@ -235,4 +235,50 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
|
||||
}
|
||||
}
|
||||
|
||||
int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
|
||||
return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies<int64>());
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> CommonFactors(
|
||||
tensorflow::gtl::ArraySlice<int64> a,
|
||||
tensorflow::gtl::ArraySlice<int64> b) {
|
||||
CHECK_EQ(Product(a), Product(b));
|
||||
if (0 == Product(a)) {
|
||||
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> bounds;
|
||||
for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
|
||||
partial_size_b = 1;
|
||||
;) {
|
||||
if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
|
||||
std::tie(prior_i, prior_j) = std::make_pair(i, j);
|
||||
bounds.emplace_back(i, j);
|
||||
continue;
|
||||
}
|
||||
bool in_bounds_i = i < a.size();
|
||||
bool in_bounds_j = j < b.size();
|
||||
if (!(in_bounds_i || in_bounds_j)) {
|
||||
break;
|
||||
}
|
||||
bool next_a =
|
||||
partial_size_a < partial_size_b ||
|
||||
(in_bounds_i &&
|
||||
(!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
|
||||
bool next_b =
|
||||
partial_size_b < partial_size_a ||
|
||||
(in_bounds_j &&
|
||||
(!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
|
||||
if (next_a) {
|
||||
partial_size_a *= a[i];
|
||||
++i;
|
||||
}
|
||||
if (next_b) {
|
||||
partial_size_b *= b[j];
|
||||
++j;
|
||||
}
|
||||
}
|
||||
return bounds;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -238,6 +238,22 @@ std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
|
||||
return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
|
||||
}
|
||||
|
||||
int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
|
||||
|
||||
// Returns the start indices of consecutive non-overlapping subsequences of `a`
|
||||
// and `b` with the same product, i.e. `(i, j)` so
|
||||
// • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
|
||||
// • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
|
||||
// • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
|
||||
// a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
|
||||
// b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
|
||||
// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
|
||||
//
|
||||
// If the given shapes have non-zero size, returns the bounds of the shortest
|
||||
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
|
||||
std::vector<std::pair<int64, int64>> CommonFactors(
|
||||
tensorflow::gtl::ArraySlice<int64> a, tensorflow::gtl::ArraySlice<int64> b);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#define XLA_LOG_LINES(SEV, STRING) LogLines(SEV, STRING, __FILE__, __LINE__)
|
||||
|
@ -85,5 +85,22 @@ TEST(UtilTest, LogLines) {
|
||||
LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
TEST(UtilTest, CommonFactors) {
|
||||
struct {
|
||||
std::vector<int64> a, b;
|
||||
std::vector<std::pair<int64, int64>> expected;
|
||||
} test_cases[] = {
|
||||
{/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
|
||||
{/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},
|
||||
{/*.a =*/{2, 5, 1, 3},
|
||||
/*.b =*/{1, 10, 3, 1},
|
||||
/*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
|
||||
};
|
||||
for (const auto& test_case : test_cases) {
|
||||
EXPECT_TRUE(ContainersEqual(test_case.expected,
|
||||
CommonFactors(test_case.a, test_case.b)));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user