namespace xla {
namespace array_impl {

// conjunction
//
// Performs a compile-time logical AND operation on the passed types (which
// must have `::value` members convertible to `bool`). Short-circuits if it
// encounters any `false` members (and does not compare the `::value` members
// of any remaining arguments).
//
// This metafunction is designed to be a drop-in replacement for the C++17
// `std::conjunction` metafunction.
template <typename... Ts>
struct conjunction;

// Recursive case: if T::value is true, the result is the conjunction of the
// remaining pack; otherwise the result is T itself (short-circuit).
template <typename T, typename... Ts>
struct conjunction<T, Ts...>
    : std::conditional<T::value, conjunction<Ts...>, T>::type {};

// Base case: an empty pack is vacuously true.
template <>
struct conjunction<> : std::true_type {};

// A type trait that is valid when all elements in a parameter pack are of
// integral type.
template <typename... T>
using pack_is_integral = conjunction<std::is_integral<T>...>;

}  // namespace array_impl
}  // namespace xla
namespace xla {
namespace array_impl {

// Compares three same-sized indexable containers elementwise. Returns true
// iff, for every i, values[i] lies inside the half-open range
// [range_starts[i], range_ends[i]); returns false as soon as any element is
// outside its range.
template <typename C1, typename C2, typename C3>
bool all_inside_range(const C1& values, const C2& range_starts,
                      const C3& range_ends) {
  for (std::size_t i = 0, e = values.size(); i < e; ++i) {
    if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
      return false;
    }
  }
  return true;
}

}  // namespace array_impl
}  // namespace xla
- Array(std::initializer_list< - std::initializer_list>>> - values) + Array(InitializerList4D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size(), values.begin()->begin()->begin()->size()})) { @@ -173,10 +225,46 @@ class Array { } } + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, &values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + + // Invokes a callback with the (indices, value) for each cell in the array. + // If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. + // + // The type trait is required to avoid this overload participating too + // eagerly; a parameter pack can take zero or more elements, so we must + // restrict this to only parameter packs that are all of integral type. template - const T& operator()(Dims... dims) const { + typename std::enable_if::value, + const T&>::type + operator()(Dims... dims) const { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -186,7 +274,9 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. 
template - T& operator()(Dims... dims) { + typename std::enable_if::value, + T&>::type + operator()(Dims... dims) { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -255,6 +345,59 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } + // Performs the equivalent of a slice operation on this array. + Array Slice(tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits) const { + CHECK_EQ(starts.size(), num_dimensions()); + CHECK_EQ(limits.size(), num_dimensions()); + + std::vector sizes; + std::transform(starts.begin(), starts.end(), limits.begin(), + std::back_inserter(sizes), + [](int64 start, int64 limit) { return limit - start; }); + Array result(sizes); + + std::vector index(sizes_.size()); + int64 slice_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, starts, limits)) { + // Even though the bounds of result are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + result.values_[slice_i++] = values_[i]; + } + } + return result; + } + + // Performs the equivalent of a DynamicUpdateSlice in-place on this array. + void UpdateSlice(const Array& from, + tensorflow::gtl::ArraySlice start_indices) { + CHECK_EQ(from.num_dimensions(), num_dimensions()); + std::vector limit_indices; + std::transform(start_indices.begin(), start_indices.end(), + from.dimensions().begin(), std::back_inserter(limit_indices), + std::plus{}); + std::vector index(sizes_.size()); + int64 from_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, start_indices, limit_indices)) { + // Even though the bounds of from are different to our bounds, we're + // iterating in the same order. 
So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + values_[i] = from.values_[from_i++]; + } + } + } + + // Performs an in-place reshape, modifying the dimensions but not the + // underlying data. + void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + int64 old_num_elements = num_elements(); + sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); + CHECK_EQ(num_elements(), old_num_elements); + } + // Returns a string representation of the array suitable for debugging. string ToString() const { std::vector pieces; diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 093784f541b..8b941947747 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(ArrayTest, DynamicIndexingReadWrite) { + Array arr({2, 3}); + + std::vector index1 = {1, 1}; + std::vector index2 = {1, 2}; + EXPECT_EQ(arr(index1), 0); + EXPECT_EQ(arr(index2), 0); + arr(index1) = 51; + arr(index2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + TEST(ArrayTest, IndexingReadWriteBool) { Array arr{{false, true, false}, {false, true, false}}; @@ -141,5 +154,37 @@ TEST(ArrayTest, Each) { EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); } +TEST(ArrayTest, Slice) { + Array arr({2, 4}); + arr.FillWithMultiples(1); + + Array identity_slice = arr.Slice({0, 0}, {2, 4}); + EXPECT_EQ(identity_slice.dimensions(), arr.dimensions()); + for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end(); + it1 != e; ++it1, ++it2) { + EXPECT_EQ(*it1, *it2); + } + + Array sub_slice = arr.Slice({1, 0}, {2, 2}); + EXPECT_EQ(sub_slice.dimensions(), (std::vector{1, 2})); + const string expected = R"([[4, 5]])"; + EXPECT_EQ(expected, sub_slice.ToString()); +} + +TEST(ArrayTest, UpdateSlice) { + Array arr({3, 4}); + 
arr.FillWithMultiples(1); + + Array sub_arr({2, 2}); + sub_arr.FillWithMultiples(3); + + arr.UpdateSlice(sub_arr, {1, 1}); + + const string expected = R"([[0, 1, 2, 3], + [4, 0, 3, 7], + [8, 6, 9, 11]])"; + EXPECT_EQ(expected, arr.ToString()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 8e1b4be1f3e..4c6e320557f 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -68,6 +68,7 @@ class ShardingBuilder { const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 64a88164a70..d174f05aa6b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -863,6 +863,11 @@ class HloInstruction { return *window_; } + // Sets the window data in a windowed operation such as convolution. + void set_window(const Window& window) { + window_ = MakeUnique(window); + } + // Returns the padding configuration for a pad node. 
// // Precondition: opcode() == HloOpcode::kPad diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index bc5663513b9..73566634542 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -249,7 +249,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || + proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } // Some versions of gcc cannot infer the TileAssignment constructor from a