Extend the Array class with more functionality.
PiperOrigin-RevId: 175277161
This commit is contained in:
parent
8d46b72fdc
commit
593dfb6a34
tensorflow/compiler/xla
@ -340,6 +340,7 @@ cc_library(
|
||||
name = "array",
|
||||
hdrs = ["array.h"],
|
||||
deps = [
|
||||
":status",
|
||||
":types",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
|
@ -23,8 +23,10 @@ limitations under the License.
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/bits.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -35,10 +37,63 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace array_impl {
|
||||
|
||||
// conjunction
|
||||
//
|
||||
// Performs a compile-time logical AND operation on the passed types (which
|
||||
// must have `::value` members convertible to `bool`. Short-circuits if it
|
||||
// encounters any `false` members (and does not compare the `::value` members
|
||||
// of any remaining arguments).
|
||||
//
|
||||
// This metafunction is designed to be a drop-in replacement for the C++17
|
||||
// `std::conjunction` metafunction.
|
||||
template <typename... Ts>
|
||||
struct conjunction;
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
struct conjunction<T, Ts...>
|
||||
: std::conditional<T::value, conjunction<Ts...>, T>::type {};
|
||||
|
||||
template <>
|
||||
struct conjunction<> : std::true_type {};
|
||||
|
||||
// A type trait that is valid when all elements in a parameter pack are of
|
||||
// integral type.
|
||||
template <typename... T>
|
||||
using pack_is_integral = conjunction<std::is_integral<T>...>;
|
||||
|
||||
// Compares three same-sized vectors elementwise. For each item in `values`,
|
||||
// returns false if any of values[i] is outside the half-open range [starts[i],
|
||||
// ends[i]).
|
||||
template <typename C1, typename C2, typename C3>
|
||||
bool all_inside_range(const C1& values, const C2& range_starts,
|
||||
const C3& range_ends) {
|
||||
for (size_t i = 0, e = values.size(); i < e; ++i) {
|
||||
if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace array_impl
|
||||
|
||||
// General N dimensional array class with arbitrary value type.
|
||||
template <typename T>
|
||||
class Array {
|
||||
public:
|
||||
// Type inference can have a hard time parsing very deep initializer list
|
||||
// nests, especially if one or more dimensions is one as the compiler just
|
||||
// sees a single-element integer initializer. These typedefs allow casting
|
||||
// explicitly with less typing.
|
||||
using InitializerList1D = std::initializer_list<T>;
|
||||
using InitializerList2D = std::initializer_list<InitializerList1D>;
|
||||
using InitializerList3D = std::initializer_list<InitializerList2D>;
|
||||
using InitializerList4D = std::initializer_list<InitializerList3D>;
|
||||
|
||||
using value_type = T;
|
||||
|
||||
// Creates a new array with the specified dimensions.
|
||||
explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
|
||||
: Array(sizes, T()) {}
|
||||
@ -53,7 +108,7 @@ class Array {
|
||||
// Creates a 2D array from the given nested initializer list. The outer
|
||||
// initializer list is the first dimension, the inner is the second dimension.
|
||||
// For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
|
||||
Array(std::initializer_list<std::initializer_list<T>> values)
|
||||
Array(InitializerList2D values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size()})) {
|
||||
int64 idx = 0;
|
||||
for (const auto& it1 : values) {
|
||||
@ -67,8 +122,7 @@ class Array {
|
||||
|
||||
// Creates a 3D array from the given nested initializer list. The outer
|
||||
// initializer list is the first dimension, and so on.
|
||||
Array(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
|
||||
values)
|
||||
Array(InitializerList3D values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size(),
|
||||
values.begin()->begin()->size()})) {
|
||||
int64 idx = 0;
|
||||
@ -85,9 +139,7 @@ class Array {
|
||||
|
||||
// Creates a 4D array from the given nested initializer list. The outer
|
||||
// initializer list is the first dimension, and so on.
|
||||
Array(std::initializer_list<
|
||||
std::initializer_list<std::initializer_list<std::initializer_list<T>>>>
|
||||
values)
|
||||
Array(InitializerList4D values)
|
||||
: Array(ToInt64Vector({values.size(), values.begin()->size(),
|
||||
values.begin()->begin()->size(),
|
||||
values.begin()->begin()->begin()->size()})) {
|
||||
@ -173,10 +225,46 @@ class Array {
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes a callback with the (indices, value_ptr) for each cell in the
|
||||
// array. If a callback returns a non-OK status, returns that else returns
|
||||
// Status::OK().
|
||||
Status EachStatus(
|
||||
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
|
||||
std::vector<int64> index(sizes_.size());
|
||||
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
|
||||
Status s = f(index, &values_[i]);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Invokes a callback with the (indices, value) for each cell in the array.
|
||||
// If a callback returns a non-OK status, returns that else returns
|
||||
// Status::OK().
|
||||
Status EachStatus(
|
||||
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
|
||||
std::vector<int64> index(sizes_.size());
|
||||
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
|
||||
Status s = f(index, values_[i]);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns the value at the cell specified by the indexes. The number of
|
||||
// arguments have to match with the number of dimensions for the array.
|
||||
//
|
||||
// The type trait is required to avoid this overload participating too
|
||||
// eagerly; a parameter pack can take zero or more elements, so we must
|
||||
// restrict this to only parameter packs that are all of integral type.
|
||||
template <typename... Dims>
|
||||
const T& operator()(Dims... dims) const {
|
||||
typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
|
||||
const T&>::type
|
||||
operator()(Dims... dims) const {
|
||||
// We are using a std::array to avoid having to allocate memory in this
|
||||
// function for performance reasons.
|
||||
std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
|
||||
@ -186,7 +274,9 @@ class Array {
|
||||
// Returns the value at the cell specified by the indexes. The number of
|
||||
// arguments have to match with the number of dimensions for the array.
|
||||
template <typename... Dims>
|
||||
T& operator()(Dims... dims) {
|
||||
typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
|
||||
T&>::type
|
||||
operator()(Dims... dims) {
|
||||
// We are using a std::array to avoid having to allocate memory in this
|
||||
// function for performance reasons.
|
||||
std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
|
||||
@ -255,6 +345,59 @@ class Array {
|
||||
|
||||
bool operator!=(const Array<T>& other) const { return !(*this == other); }
|
||||
|
||||
// Performs the equivalent of a slice operation on this array.
|
||||
Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
|
||||
tensorflow::gtl::ArraySlice<int64> limits) const {
|
||||
CHECK_EQ(starts.size(), num_dimensions());
|
||||
CHECK_EQ(limits.size(), num_dimensions());
|
||||
|
||||
std::vector<int64> sizes;
|
||||
std::transform(starts.begin(), starts.end(), limits.begin(),
|
||||
std::back_inserter(sizes),
|
||||
[](int64 start, int64 limit) { return limit - start; });
|
||||
Array<T> result(sizes);
|
||||
|
||||
std::vector<int64> index(sizes_.size());
|
||||
int64 slice_i = 0;
|
||||
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
|
||||
if (array_impl::all_inside_range(index, starts, limits)) {
|
||||
// Even though the bounds of result are different to our bounds, we're
|
||||
// iterating in the same order. So we can simply write successive linear
|
||||
// indices instead of recalculating a multi-dimensional index.
|
||||
result.values_[slice_i++] = values_[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Performs the equivalent of a DynamicUpdateSlice in-place on this array.
|
||||
void UpdateSlice(const Array<T>& from,
|
||||
tensorflow::gtl::ArraySlice<int64> start_indices) {
|
||||
CHECK_EQ(from.num_dimensions(), num_dimensions());
|
||||
std::vector<int64> limit_indices;
|
||||
std::transform(start_indices.begin(), start_indices.end(),
|
||||
from.dimensions().begin(), std::back_inserter(limit_indices),
|
||||
std::plus<int64>{});
|
||||
std::vector<int64> index(sizes_.size());
|
||||
int64 from_i = 0;
|
||||
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
|
||||
if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
|
||||
// Even though the bounds of from are different to our bounds, we're
|
||||
// iterating in the same order. So we can simply write successive linear
|
||||
// indices instead of recalculating a multi-dimensional index.
|
||||
values_[i] = from.values_[from_i++];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Performs an in-place reshape, modifying the dimensions but not the
|
||||
// underlying data.
|
||||
void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
|
||||
int64 old_num_elements = num_elements();
|
||||
sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
|
||||
CHECK_EQ(num_elements(), old_num_elements);
|
||||
}
|
||||
|
||||
// Returns a string representation of the array suitable for debugging.
|
||||
string ToString() const {
|
||||
std::vector<string> pieces;
|
||||
|
@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) {
|
||||
EXPECT_EQ(arr(1, 2), 61);
|
||||
}
|
||||
|
||||
TEST(ArrayTest, DynamicIndexingReadWrite) {
|
||||
Array<int> arr({2, 3});
|
||||
|
||||
std::vector<int64> index1 = {1, 1};
|
||||
std::vector<int64> index2 = {1, 2};
|
||||
EXPECT_EQ(arr(index1), 0);
|
||||
EXPECT_EQ(arr(index2), 0);
|
||||
arr(index1) = 51;
|
||||
arr(index2) = 61;
|
||||
EXPECT_EQ(arr(1, 1), 51);
|
||||
EXPECT_EQ(arr(1, 2), 61);
|
||||
}
|
||||
|
||||
TEST(ArrayTest, IndexingReadWriteBool) {
|
||||
Array<bool> arr{{false, true, false}, {false, true, false}};
|
||||
|
||||
@ -141,5 +154,37 @@ TEST(ArrayTest, Each) {
|
||||
EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum);
|
||||
}
|
||||
|
||||
TEST(ArrayTest, Slice) {
|
||||
Array<int64> arr({2, 4});
|
||||
arr.FillWithMultiples(1);
|
||||
|
||||
Array<int64> identity_slice = arr.Slice({0, 0}, {2, 4});
|
||||
EXPECT_EQ(identity_slice.dimensions(), arr.dimensions());
|
||||
for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end();
|
||||
it1 != e; ++it1, ++it2) {
|
||||
EXPECT_EQ(*it1, *it2);
|
||||
}
|
||||
|
||||
Array<int64> sub_slice = arr.Slice({1, 0}, {2, 2});
|
||||
EXPECT_EQ(sub_slice.dimensions(), (std::vector<int64>{1, 2}));
|
||||
const string expected = R"([[4, 5]])";
|
||||
EXPECT_EQ(expected, sub_slice.ToString());
|
||||
}
|
||||
|
||||
TEST(ArrayTest, UpdateSlice) {
|
||||
Array<int64> arr({3, 4});
|
||||
arr.FillWithMultiples(1);
|
||||
|
||||
Array<int64> sub_arr({2, 2});
|
||||
sub_arr.FillWithMultiples(3);
|
||||
|
||||
arr.UpdateSlice(sub_arr, {1, 1});
|
||||
|
||||
const string expected = R"([[0, 1, 2, 3],
|
||||
[4, 0, 3, 7],
|
||||
[8, 6, 9, 11]])";
|
||||
EXPECT_EQ(expected, arr.ToString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -68,6 +68,7 @@ class ShardingBuilder {
|
||||
const TileAssignment& tile_assignment) {
|
||||
OpSharding result;
|
||||
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
|
||||
*result.mutable_tile_shape() = tile_shape;
|
||||
for (int64 dim : tile_assignment.dimensions()) {
|
||||
result.add_tile_assignment_dimensions(dim);
|
||||
}
|
||||
|
@ -863,6 +863,11 @@ class HloInstruction {
|
||||
return *window_;
|
||||
}
|
||||
|
||||
// Sets the window data in a windowed operation such as convolution.
|
||||
void set_window(const Window& window) {
|
||||
window_ = MakeUnique<Window>(window);
|
||||
}
|
||||
|
||||
// Returns the padding configuration for a pad node.
|
||||
//
|
||||
// Precondition: opcode() == HloOpcode::kPad
|
||||
|
@ -249,7 +249,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
|
||||
return HloSharding(tuple_shardings);
|
||||
} else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
|
||||
return Replicate();
|
||||
} else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
|
||||
} else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
|
||||
proto.tile_assignment_devices().size() == 1) {
|
||||
return HloSharding(proto.tile_assignment_devices(0));
|
||||
}
|
||||
// Some versions of gcc cannot infer the TileAssignment constructor from a
|
||||
|
Loading…
Reference in New Issue
Block a user