[XLA] Compute source indices in each common factor of reshape from target indices in same common factor

Change: 144124189
Commit: e2730973b1 (parent: b76344d0ab)

@@ -69,7 +69,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
   CHECK_EQ(shape.dimensions_size(), multidim.size());
   CHECK(LayoutUtil::HasLayout(shape));
-  linear_ = Linearize(shape, ir_builder);
+  linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder);
 }
 
 IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)

@@ -109,35 +109,41 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
     llvm::IRBuilder<>* builder) const {
   const auto& target_index = *this;
   CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
-  llvm::Value* logical_linear_index = Linearize(output_shape, builder);
-  // Delinearizes logical_linear_index for the source array in row-major
-  // collapsed order. The first rank-1 indices are the remainder of the
-  // linear index by each dimension size.
-  std::vector<std::pair<int64, int64>> unmodified_dims =
-      ShapeUtil::DimensionsUnmodifiedByReshape(input_shape, output_shape);
-  std::vector<llvm::Value*> source_multidim_index(ShapeUtil::Rank(input_shape));
-  for (int64 i = ShapeUtil::Rank(input_shape) - 1; i >= 0; --i) {
-    auto divisor = builder->getInt64(input_shape.dimensions(i));
-    if (input_shape.dimensions(i) <= 1) {
-      source_multidim_index[i] = builder->getInt64(0);
-    } else {
-      // Search unmodified_dims for a pair whose first element is exactly "i".
-      //
-      // Because unmodified_dims are sorted by both "first" and "second", and
-      // "i" is monotonically decreasing, we avoid redundant searching by
-      // popping the back of unmodified_dims until the rear pair's first element
-      // <= i. If we stop precisely at "i", we find a match.
-      while (!unmodified_dims.empty() && unmodified_dims.back().first > i) {
-        unmodified_dims.pop_back();
-      }
-      if (!unmodified_dims.empty() && unmodified_dims.back().first == i) {
-        source_multidim_index[i] = target_index[unmodified_dims.back().second];
+  std::vector<std::pair<int64, int64>> common_factors =
+      CommonFactors(AsInt64Slice(input_shape.dimensions()),
+                    AsInt64Slice(output_shape.dimensions()));
+  std::vector<llvm::Value*> source_multidim_index(
+      ShapeUtil::Rank(input_shape),
+      llvm::UndefValue::get(builder->getInt64Ty()));
+  // We compute the source indices in each common factor from only the target
+  // indices in the same common factor.
+  for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
+    llvm::Value* logical_linear_index =
+        Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
+                  multidim_, common_factors[k].second,
+                  common_factors[k + 1].second - common_factors[k].second))
+            .Linearize(
+                tensorflow::gtl::ArraySlice<int64>(
+                    AsInt64Slice(output_shape.dimensions()),
+                    common_factors[k].second,
+                    common_factors[k + 1].second - common_factors[k].second),
+                builder);
+    // Delinearizes logical_linear_index for the source array in row-major
+    // collapsed order. The first rank-1 indices are the remainder of the
+    // linear index by each dimension size.
+    for (int64 i = common_factors[k + 1].first - 1;
+         i >= common_factors[k].first; --i) {
+      llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i));
+      if (input_shape.dimensions(i) == 1) {
+        source_multidim_index[i] = builder->getInt64(0);
+      } else if (i == common_factors[k].first) {
+        source_multidim_index[i] = logical_linear_index;
       } else {
         source_multidim_index[i] =
             builder->CreateURem(logical_linear_index, divisor);
       }
+      logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
     }
-    logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
   }
 
   if (linear() != nullptr &&

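To make the new per-factor control flow concrete, here is a standalone sketch (not part of the commit) that mirrors the computation above with plain integers in place of LLVM IR values. The shapes and factor bounds are the ones exercised by the CommonFactors test case added in util_test.cc below; the sample target index is made up for illustration, and the size-one-dimension special case is folded into the general one.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const std::vector<int64_t> input_dims = {2, 5, 1, 3};    // reshape source
  const std::vector<int64_t> output_dims = {1, 10, 3, 1};  // reshape target
  // CommonFactors({2, 5, 1, 3}, {1, 10, 3, 1}), per the new test below.
  const std::vector<std::pair<int64_t, int64_t>> common_factors = {
      {0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}};

  const std::vector<int64_t> target_index = {0, 7, 2, 0};  // arbitrary example
  std::vector<int64_t> source_index(input_dims.size(), 0);

  for (int k = static_cast<int>(common_factors.size()) - 2; k >= 0; --k) {
    // Linearize (row-major) the target indices that lie in common factor k.
    int64_t linear = 0;
    for (int64_t j = common_factors[k].second;
         j < common_factors[k + 1].second; ++j) {
      linear = linear * output_dims[j] + target_index[j];
    }
    // Delinearize into the source dimensions of the same common factor; the
    // factor's first dimension absorbs whatever quotient remains.
    for (int64_t i = common_factors[k + 1].first - 1;
         i >= common_factors[k].first; --i) {
      source_index[i] =
          (i == common_factors[k].first) ? linear : linear % input_dims[i];
      linear /= input_dims[i];
    }
  }

  // Target (0, 7, 2, 0) maps to source (1, 2, 0, 2); both linearize to 23.
  for (int64_t d : source_index) std::cout << d << ' ';
  std::cout << '\n';
}
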
@@ -160,8 +166,9 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose(
   return Index(operand_multidim_index);
 }
 
-llvm::Value* IrArray::Index::Linearize(const Shape& shape,
-                                       llvm::IRBuilder<>* builder) const {
+llvm::Value* IrArray::Index::Linearize(
+    tensorflow::gtl::ArraySlice<int64> dimensions,
+    llvm::IRBuilder<>* builder) const {
   // Each dimension is multiplied by the product of the sizes of all
   // earlier dimensions and added to the accumulator logical_linear_index.
   llvm::Value* logical_linear_index = builder->getInt64(0);

@@ -172,7 +179,7 @@ llvm::Value* IrArray::Index::Linearize(const Shape& shape,
                                /*HasNUW=*/true, /*HasNSW=*/true);
     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                               /*HasNUW=*/true, /*HasNSW=*/true);
-    multiplier *= shape.dimensions(i);
+    multiplier *= dimensions[i];
   }
   return logical_linear_index;
 }

@@ -124,7 +124,7 @@ class IrArray {
 
   // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
   // returns the index into the sole dimension 0 of the new shape.
-  llvm::Value* Linearize(const Shape& shape,
+  llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
                          llvm::IRBuilder<>* builder) const;
 
  private:

@@ -767,57 +767,20 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
 /* static */ std::vector<std::pair<int64, int64>>
 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
                                          const Shape& output_shape) {
-  // Returns nil if the input/output shape has zero elements. This is safe but
-  // might be too conservative. Not a big deal for now because IR emitted for
-  // zero-element shapes are often trivially optimizable without the help of
-  // this method.
-  if (ShapeUtil::ElementsIn(input_shape) == 0 ||
-      ShapeUtil::ElementsIn(output_shape) == 0) {
-    return std::vector<std::pair<int64, int64>>();
-  }
-
-  std::vector<std::pair<int64, int64>> unmodified_dims;
-  int64 input_dim = 0;
-  int64 output_dim = 0;
-
-  // A reshape preserves input_dim as output_dim iff
-  // 1. input_dim and output_dim have the same size.
-  // 2. The size of the input subarray from dimension 0 to input_dim-1 equals
-  // that of the output subarray from dimension 0 to output_dim-1.
-  VLOG(3) << "DimensionsUnmodifiedByReshape: input_shape="
-          << ShapeUtil::HumanString(input_shape)
-          << ", output_shape=" << ShapeUtil::HumanString(output_shape);
-  while (input_dim < ShapeUtil::Rank(input_shape) &&
-         output_dim < ShapeUtil::Rank(output_shape)) {
-    // partial_input_size is the product of sizes of input dimensions
-    // inclusively between the input_dim when this loop iteration starts and the
-    // current input_dim. partial_output_size is that of output dimensions. We
-    // compute these two values incrementally to save time.
-    int64 partial_input_size = input_shape.dimensions(input_dim);
-    int64 partial_output_size = output_shape.dimensions(output_dim);
-    // Move input_dim and output_dim forward until
-    // partial_input_size==partial_output_size.
-    while (partial_input_size != partial_output_size) {
-      if (partial_input_size < partial_output_size) {
-        ++input_dim;
-        partial_input_size *= input_shape.dimensions(input_dim);
-      } else {
-        ++output_dim;
-        partial_output_size *= output_shape.dimensions(output_dim);
-      }
+  // Unmodified dimensions are merely common factors of rank 1.
+  auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
+                                      AsInt64Slice(output_shape.dimensions()));
+  for (size_t i = 0; i < common_factors.size() - 1;) {
+    if (1 != common_factors[i + 1].first - common_factors[i].first ||
+        1 != common_factors[i + 1].second - common_factors[i].second) {
+      common_factors.erase(common_factors.begin() + i);
+    } else {
+      ++i;
     }
-    CHECK_LT(input_dim, ShapeUtil::Rank(input_shape));
-    CHECK_LT(output_dim, ShapeUtil::Rank(output_shape));
-    if (input_shape.dimensions(input_dim) ==
-        output_shape.dimensions(output_dim)) {
-      unmodified_dims.push_back({input_dim, output_dim});
-      VLOG(3) << "Matching dimension pair: " << input_dim << ' ' << output_dim;
-    }
-    ++input_dim;
-    ++output_dim;
   }
-
-  return unmodified_dims;
+  // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
+  common_factors.pop_back();
+  return common_factors;
 }
 
 /* static */ bool ShapeUtil::TransposeIsBitcast(

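Worked example: for input shape {2, 5, 1, 3} and output shape {1, 10, 3, 1} (the case exercised in the new util_test.cc below), CommonFactors returns {(0, 0), (0, 1), (2, 2), (3, 2), (4, 3), (4, 4)}. The only adjacent pair of bounds that advances both coordinates by exactly one is (3, 2) -> (4, 3), so the loop above keeps just (3, 2): input dimension 3 and output dimension 2, both of size 3, form the sole unmodified pair.
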
@@ -235,4 +235,50 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
   }
 }
 
+int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
+  return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies<int64>());
+}
+
+std::vector<std::pair<int64, int64>> CommonFactors(
+    tensorflow::gtl::ArraySlice<int64> a,
+    tensorflow::gtl::ArraySlice<int64> b) {
+  CHECK_EQ(Product(a), Product(b));
+  if (0 == Product(a)) {
+    return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
+  }
+
+  std::vector<std::pair<int64, int64>> bounds;
+  for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
+             partial_size_b = 1;
+       ;) {
+    if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
+      std::tie(prior_i, prior_j) = std::make_pair(i, j);
+      bounds.emplace_back(i, j);
+      continue;
+    }
+    bool in_bounds_i = i < a.size();
+    bool in_bounds_j = j < b.size();
+    if (!(in_bounds_i || in_bounds_j)) {
+      break;
+    }
+    bool next_a =
+        partial_size_a < partial_size_b ||
+        (in_bounds_i &&
+         (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
+    bool next_b =
+        partial_size_b < partial_size_a ||
+        (in_bounds_j &&
+         (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
+    if (next_a) {
+      partial_size_a *= a[i];
+      ++i;
+    }
+    if (next_b) {
+      partial_size_b *= b[j];
+      ++j;
+    }
+  }
+  return bounds;
+}
+
 }  // namespace xla

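As a cross-check, the loop above can be ported onto plain std::vector<int64_t> and run standalone. The sketch below (not part of the commit) drops the TF-specific ArraySlice type, the CHECK, and the zero-element early return, but keeps the traversal logic; running it reproduces the bounds asserted by the new test below.

#include <cstdint>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> CommonFactorsSketch(
    const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  std::vector<std::pair<int64_t, int64_t>> bounds;
  int64_t i = 0, j = 0, prior_i = -1, prior_j = -1;
  int64_t partial_size_a = 1, partial_size_b = 1;
  while (true) {
    // Emit a bound whenever the partial products agree and we made progress.
    if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
      std::tie(prior_i, prior_j) = std::make_pair(i, j);
      bounds.emplace_back(i, j);
      continue;
    }
    const bool in_bounds_i = i < static_cast<int64_t>(a.size());
    const bool in_bounds_j = j < static_cast<int64_t>(b.size());
    if (!(in_bounds_i || in_bounds_j)) break;
    // Advance whichever side has the smaller partial product (or both on
    // ties), exactly as in the implementation above.
    const bool next_a =
        partial_size_a < partial_size_b ||
        (in_bounds_i &&
         (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
    const bool next_b =
        partial_size_b < partial_size_a ||
        (in_bounds_j &&
         (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
    if (next_a) partial_size_a *= a[i++];
    if (next_b) partial_size_b *= b[j++];
  }
  return bounds;
}

int main() {
  // Prints: (0,0) (0,1) (2,2) (3,2) (4,3) (4,4)
  for (const auto& bound : CommonFactorsSketch({2, 5, 1, 3}, {1, 10, 3, 1})) {
    std::cout << '(' << bound.first << ',' << bound.second << ") ";
  }
  std::cout << '\n';
}
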
@@ -238,6 +238,22 @@ std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
   return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
 }
 
+int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
+
+// Returns the start indices of consecutive non-overlapping subsequences of `a`
+// and `b` with the same product, i.e. `(i, j)` so
+// • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
+// • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
+// • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
+//         a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
+//         b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
+// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
+//
+// If the given shapes have non-zero size, returns the bounds of the shortest
+// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
+std::vector<std::pair<int64, int64>> CommonFactors(
+    tensorflow::gtl::ArraySlice<int64> a, tensorflow::gtl::ArraySlice<int64> b);
+
 }  // namespace xla
 
 #define XLA_LOG_LINES(SEV, STRING) LogLines(SEV, STRING, __FILE__, __LINE__)

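Instantiating this contract on the new test case: for a = {2, 5, 1, 3} and b = {1, 10, 3, 1}, CommonFactors returns {(0, 0), (0, 1), (2, 2), (3, 2), (4, 3), (4, 4)}. Each adjacent pair of bounds delimits segments with equal products, e.g. between (0, 1) and (2, 2) we have a[0] × a[1] = 10 = b[1], and the final entry is (a.size, b.size) = (4, 4).
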
@@ -85,5 +85,22 @@ TEST(UtilTest, LogLines) {
   LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
 }
 
+TEST(UtilTest, CommonFactors) {
+  struct {
+    std::vector<int64> a, b;
+    std::vector<std::pair<int64, int64>> expected;
+  } test_cases[] = {
+      {/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
+      {/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},
+      {/*.a =*/{2, 5, 1, 3},
+       /*.b =*/{1, 10, 3, 1},
+       /*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
+  };
+  for (const auto& test_case : test_cases) {
+    EXPECT_TRUE(ContainersEqual(test_case.expected,
+                                CommonFactors(test_case.a, test_case.b)));
+  }
+}
+
 }  // namespace
 }  // namespace xla