Change some std::vector to absl::InlinedVector inside Shape
We are creating a lot of instance of the Shape class during compilation and before this change each instance required multiple allocations. This change modifies the storage type for most member fields to not require memory allocation in the common case. PiperOrigin-RevId: 266142796
This commit is contained in:
parent
912db4a625
commit
2eda355277
@ -1,4 +1,4 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||
@ -244,6 +244,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
@ -265,6 +266,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
@ -85,7 +86,7 @@ xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter(
|
||||
int batch_dim, int depth_dim) {
|
||||
// 1. Reshape the depth_dim C into [G, C/G]
|
||||
int num_dims = input_shape.dimensions_size();
|
||||
std::vector<int64> reshape_dims = input_shape.dimensions();
|
||||
std::vector<int64> reshape_dims = xla::SpanToVector(input_shape.dimensions());
|
||||
reshape_dims[depth_dim] = reshape_dims[depth_dim] / num_groups;
|
||||
reshape_dims.insert(reshape_dims.begin() + depth_dim, num_groups);
|
||||
xla::XlaOp result = xla::Reshape(input, reshape_dims);
|
||||
|
@ -57,7 +57,7 @@ class DepthToSpaceOp : public XlaOpKernel {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto input_xla_shape = builder->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, input_xla_shape.status());
|
||||
const std::vector<int64>& input_shape =
|
||||
absl::Span<const int64> input_shape =
|
||||
input_xla_shape.ValueOrDie().dimensions();
|
||||
int input_rank = input_shape.size();
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -153,7 +154,8 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
||||
lhs_dilation, rhs_dilation, dims, depth);
|
||||
// Feature group convolution, will end up with the kernel_size change more
|
||||
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
|
||||
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
|
||||
std::vector<int64> conv_dims =
|
||||
xla::SpanToVector(builder->GetShape(conv).ValueOrDie().dimensions());
|
||||
conv_dims.back() = depth;
|
||||
conv_dims.push_back(kernel_size);
|
||||
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
|
||||
|
@ -57,7 +57,7 @@ class SpaceToDepthOp : public XlaOpKernel {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto input_xla_shape = builder->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, input_xla_shape.status());
|
||||
const std::vector<int64>& input_shape =
|
||||
absl::Span<const int64> input_shape =
|
||||
input_xla_shape.ValueOrDie().dimensions();
|
||||
int input_rank = input_shape.size();
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -214,7 +215,7 @@ Status GetTensorListShapeFromElementTensorListShape(
|
||||
for (int i = 0; i < tuple_size; i++) {
|
||||
const xla::Shape& shape =
|
||||
xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
|
||||
std::vector<int64> dimensions = shape.dimensions();
|
||||
std::vector<int64> dimensions = xla::SpanToVector(shape.dimensions());
|
||||
dimensions.insert(dimensions.begin(), leading_dim);
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
|
||||
@ -236,7 +237,7 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
||||
}
|
||||
|
||||
std::vector<xla::Shape> shapes;
|
||||
std::vector<int64> dimensions = element_shape.dimensions();
|
||||
std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
|
||||
dimensions.insert(dimensions.begin(), leading_dim);
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
|
||||
@ -321,7 +322,8 @@ Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
|
||||
const xla::Shape& element_part_shape =
|
||||
xla::ShapeUtil::GetTupleElementShape(element_shape, i);
|
||||
xla::XlaOp element_part = xla::GetTupleElement(element, i);
|
||||
std::vector<int64> element_part_dims = element_part_shape.dimensions();
|
||||
std::vector<int64> element_part_dims =
|
||||
xla::SpanToVector(element_part_shape.dimensions());
|
||||
element_part_dims.insert(element_part_dims.begin(), 1);
|
||||
element_part = xla::Reshape(element_part, element_part_dims);
|
||||
|
||||
@ -337,7 +339,8 @@ Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
|
||||
}
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
|
||||
std::vector<int64> element_dims = element_shape.dimensions();
|
||||
std::vector<int64> element_dims =
|
||||
xla::SpanToVector(element_shape.dimensions());
|
||||
element_dims.insert(element_dims.begin(), 1);
|
||||
xla::XlaOp update = xla::Reshape(element, element_dims);
|
||||
|
||||
@ -384,7 +387,8 @@ Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
|
||||
xla::ConstantR0<int32>(b, 0));
|
||||
start_indices[0] = push_index;
|
||||
|
||||
std::vector<int64> slice_shape = list_part_shape.dimensions();
|
||||
std::vector<int64> slice_shape =
|
||||
xla::SpanToVector(list_part_shape.dimensions());
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
xla::XlaOp list_part = xla::GetTupleElement(list, i);
|
||||
@ -422,7 +426,8 @@ Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
|
||||
|
||||
xla::XlaBuilder* b = list.builder();
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
|
||||
std::vector<int64> element_dims = element_shape.dimensions();
|
||||
std::vector<int64> element_dims =
|
||||
xla::SpanToVector(element_shape.dimensions());
|
||||
element_dims.insert(element_dims.begin(), 1);
|
||||
xla::XlaOp update = xla::Reshape(element, element_dims);
|
||||
|
||||
@ -463,7 +468,7 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
|
||||
xla::ConstantR0<int32>(b, 0));
|
||||
start_indices[0] = index;
|
||||
|
||||
std::vector<int64> slice_shape = buffer_shape.dimensions();
|
||||
std::vector<int64> slice_shape = xla::SpanToVector(buffer_shape.dimensions());
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
xla::XlaOp list_part = xla::GetTupleElement(list, 0);
|
||||
|
@ -90,6 +90,7 @@ cc_library(
|
||||
hdrs = ["data_format.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/data_format.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -57,7 +59,8 @@ xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
|
||||
|
||||
// Split the `dim` into two dimensions with a reshape. The size of the new
|
||||
// dimension is always 4.
|
||||
std::vector<int64> expanded_shape(input_shape.dimensions());
|
||||
std::vector<int64> expanded_shape =
|
||||
xla::SpanToVector(input_shape.dimensions());
|
||||
expanded_shape[dim] /= 4;
|
||||
expanded_shape.insert(expanded_shape.begin() + dim, 4);
|
||||
|
||||
|
@ -440,7 +440,7 @@ std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
|
||||
return xla::InlinedVectorToVector(
|
||||
absl::get<TensorShape>(shape).dim_sizes());
|
||||
} else {
|
||||
return absl::get<xla::Shape>(shape).dimensions();
|
||||
return xla::SpanToVector(absl::get<xla::Shape>(shape).dimensions());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
|
||||
|
||||
XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
|
||||
Shape shape = builder->GetShape(result.v).ValueOrDie();
|
||||
std::vector<int64> out_dims = shape.dimensions();
|
||||
absl::Span<const int64> out_dims = shape.dimensions();
|
||||
std::vector<int64> broadcast_dims(shape.rank() - 1);
|
||||
std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
|
||||
|
||||
|
@ -223,7 +223,7 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
|
||||
index = ConvertElementType(index, U32);
|
||||
index_shape.set_element_type(U32);
|
||||
}
|
||||
std::vector<int64> slice_sizes = input_shape.dimensions();
|
||||
std::vector<int64> slice_sizes = SpanToVector(input_shape.dimensions());
|
||||
GatherDimensionNumbers gather_dnums;
|
||||
gather_dnums.set_index_vector_dim(index_shape.rank());
|
||||
if (batch_dims > 0) {
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/types/span.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -53,7 +53,7 @@ class Tile {
|
||||
int64 dimension(int i) const { return dimensions_.at(i); }
|
||||
|
||||
// Returns the dimensions of the tile.
|
||||
const std::vector<int64>& dimensions() const { return dimensions_; }
|
||||
absl::Span<const int64> dimensions() const { return dimensions_; }
|
||||
|
||||
Tile& add_dimensions(int64 value) {
|
||||
dimensions_.push_back(value);
|
||||
@ -76,7 +76,7 @@ class Tile {
|
||||
|
||||
private:
|
||||
// The bounds of the tile.
|
||||
std::vector<int64> dimensions_;
|
||||
absl::InlinedVector<int64, 2> dimensions_;
|
||||
};
|
||||
|
||||
class Layout {
|
||||
@ -183,8 +183,10 @@ class Layout {
|
||||
minor_to_major_.clear();
|
||||
return *this;
|
||||
}
|
||||
const std::vector<int64>& minor_to_major() const { return minor_to_major_; }
|
||||
std::vector<int64>* mutable_minor_to_major() { return &minor_to_major_; }
|
||||
absl::Span<const int64> minor_to_major() const { return minor_to_major_; }
|
||||
absl::InlinedVector<int64, 6>* mutable_minor_to_major() {
|
||||
return &minor_to_major_;
|
||||
}
|
||||
|
||||
// Methods for accessing the tile field.
|
||||
int tiles_size() const { return tiles_.size(); }
|
||||
@ -198,8 +200,8 @@ class Layout {
|
||||
tiles_.clear();
|
||||
return *this;
|
||||
}
|
||||
const std::vector<Tile>& tiles() const { return tiles_; }
|
||||
std::vector<Tile>* mutable_tiles() { return &tiles_; }
|
||||
absl::Span<const Tile> tiles() const { return tiles_; }
|
||||
absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
|
||||
|
||||
// Methods for accessing the int64 fields.
|
||||
int64 max_sparse_elements() const { return max_sparse_elements_; }
|
||||
@ -250,7 +252,7 @@ class Layout {
|
||||
// The second most minor is [8,100,100,3][0], which is size 8.
|
||||
// The third most minor is [8,100,100,3][2], which is size 100.
|
||||
// And the major dim is [8,100,100,3][1], which is size 100.
|
||||
std::vector<int64> minor_to_major_;
|
||||
absl::InlinedVector<int64, 6> minor_to_major_;
|
||||
|
||||
// The maximum number of elements that can be stored for SPARSE formats. This
|
||||
// can be used to determine the maximum size in bytes of arrays stored in
|
||||
@ -258,7 +260,7 @@ class Layout {
|
||||
int64 max_sparse_elements_ = 0;
|
||||
|
||||
// The tiles used in tiling-based layout.
|
||||
std::vector<Tile> tiles_;
|
||||
absl::InlinedVector<Tile, 2> tiles_;
|
||||
|
||||
// The number of bits used to store an individual array element.
|
||||
int64 element_size_in_bits_ = 0;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
@ -41,7 +42,8 @@ namespace {
|
||||
|
||||
// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
|
||||
// minor_to_major to the value that represents the default layout.
|
||||
void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
|
||||
template <typename T>
|
||||
void SetDefaultLayoutToContainer(T* minor_to_major) {
|
||||
// The default XLA layout is major-to-minor (dim 0 is major).
|
||||
// For more information on XLA layouts, see:
|
||||
// https://www.tensorflow.org/performance/xla/shapes
|
||||
@ -105,7 +107,7 @@ namespace {
|
||||
Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
Layout layout;
|
||||
layout.set_format(DENSE);
|
||||
std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
|
||||
auto* minor_to_major = layout.mutable_minor_to_major();
|
||||
minor_to_major->resize(rank, 0);
|
||||
SetDefaultLayoutToContainer(minor_to_major);
|
||||
return layout;
|
||||
|
@ -1635,7 +1635,8 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
|
||||
// Invert reshape.
|
||||
CHECK_EQ(rhs_contracting_dims.size(), 1);
|
||||
auto rhs_unsquished_shape_dims = constant->shape().dimensions();
|
||||
std::vector<int64> rhs_unsquished_shape_dims =
|
||||
SpanToVector(constant->shape().dimensions());
|
||||
auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
|
||||
rhs_contracting_dims[0]);
|
||||
for (auto dim : lhs_contracting_dims) {
|
||||
@ -1656,7 +1657,8 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
|
||||
|
||||
// Invert transpose. First compute the shape.
|
||||
auto rhs_transpose_shape_dims = rhs_reshape->shape().dimensions();
|
||||
std::vector<int64> rhs_transpose_shape_dims =
|
||||
SpanToVector(rhs_reshape->shape().dimensions());
|
||||
it = rhs_transpose_shape_dims.erase(
|
||||
rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
|
||||
rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
|
||||
|
@ -140,7 +140,7 @@ Status ConvolutionVisitor::HandleBackwardFilterBatchGroupConvolution(
|
||||
};
|
||||
// Reshape batch_dim C -> [G, C/G] - Batch and feature dims have been
|
||||
// swapped in tf2xla bridge
|
||||
std::vector<int64> reshape_dims = lhs->shape().dimensions();
|
||||
std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
|
||||
reshape_dims[input_batch_dimension] =
|
||||
reshape_dims[input_batch_dimension] / num_groups;
|
||||
reshape_dims.insert(reshape_dims.begin() + input_batch_dimension,
|
||||
|
@ -268,7 +268,7 @@ MatchBackwardFilter(HloInstruction* conv) {
|
||||
int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
|
||||
|
||||
// Reshape batch_dim G*N -> [G,N]
|
||||
std::vector<int64> reshape_dims = lhs->shape().dimensions();
|
||||
std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
|
||||
auto num_groups = conv->feature_group_count();
|
||||
CHECK_EQ(input_batch % num_groups, 0)
|
||||
<< "Input batch should be an exact multiple of feature group count";
|
||||
@ -290,7 +290,7 @@ MatchBackwardFilter(HloInstruction* conv) {
|
||||
transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
|
||||
input_batch_dimension);
|
||||
std::vector<int64> transpose_reshape_dims =
|
||||
lhs_reshape_1->shape().dimensions();
|
||||
SpanToVector(lhs_reshape_1->shape().dimensions());
|
||||
transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
|
||||
input_batch_dimension);
|
||||
transpose_reshape_dims.insert(
|
||||
@ -539,7 +539,7 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||
reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
|
||||
reverse_filter->shape(), reverse_filter,
|
||||
AsInt64Slice(dnums.kernel_spatial_dimensions())));
|
||||
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
|
||||
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
|
||||
}
|
||||
|
||||
// Calculate the 'rhs' that goes into the backward input convolution.
|
||||
@ -572,7 +572,7 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||
|
||||
// Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
|
||||
// out_depth / G]
|
||||
std::vector<int64> reshape_dims = rhs->shape().dimensions();
|
||||
std::vector<int64> reshape_dims = SpanToVector(rhs->shape().dimensions());
|
||||
auto num_groups = conv->feature_group_count();
|
||||
CHECK_EQ(input_features % num_groups, 0)
|
||||
<< "Input feature count should be an exact multiple of feature group "
|
||||
@ -593,7 +593,8 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||
transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
|
||||
transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
|
||||
input_feature_dimension);
|
||||
std::vector<int64> transpose_reshape_dims = rhs->shape().dimensions();
|
||||
std::vector<int64> transpose_reshape_dims =
|
||||
SpanToVector(rhs->shape().dimensions());
|
||||
transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
|
||||
input_feature_dimension);
|
||||
transpose_reshape_dims.insert(
|
||||
|
@ -919,9 +919,8 @@ void Dft1D(int64 length, int64 start, int64 stride, bool inverse,
|
||||
|
||||
// Helper to reverse the order of dimension lengths in the passed-in literal.
|
||||
std::vector<int64> GetDimensionLengths(const Literal& literal) {
|
||||
std::vector<int64> lengths = literal.shape().dimensions();
|
||||
absl::c_reverse(lengths);
|
||||
return lengths;
|
||||
auto dimensions = literal.shape().dimensions();
|
||||
return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
|
||||
}
|
||||
|
||||
// Helper to compute strides for creating linear indices into multidimensional
|
||||
@ -2373,7 +2372,6 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) {
|
||||
arg_dim_steps[dim] = 1;
|
||||
arg_dim_counts[dim] = arg_dimensions[dim];
|
||||
}
|
||||
auto reduced_dimensions = arg_shape.dimensions();
|
||||
|
||||
// Map each dimension in the result to a dimension in arg that isn't
|
||||
// being reduced.
|
||||
|
@ -964,7 +964,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
|
||||
return Construct<ScalarIndexedConstantArray>(
|
||||
new_source, scalar_indexed_const->indices(),
|
||||
scalar_indexed_const->source_dim(),
|
||||
ArraySliceToVector(scalar_indexed_const->output_dims()),
|
||||
SpanToVector(scalar_indexed_const->output_dims()),
|
||||
scalar_indexed_const->shape());
|
||||
}
|
||||
|
||||
@ -1060,7 +1060,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
|
||||
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
|
||||
return Construct<ScalarIndexedConstantArray>(
|
||||
new_source, lhs->indices(), new_source_dim,
|
||||
ArraySliceToVector(lhs->output_dims()), shape);
|
||||
SpanToVector(lhs->output_dims()), shape);
|
||||
}
|
||||
|
||||
StatusOr<Analysis::Array*>
|
||||
@ -1096,7 +1096,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
|
||||
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
|
||||
return Construct<ScalarIndexedConstantArray>(
|
||||
new_source, rhs->indices(), new_source_dim,
|
||||
ArraySliceToVector(rhs->output_dims()), shape);
|
||||
SpanToVector(rhs->output_dims()), shape);
|
||||
}
|
||||
|
||||
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
|
||||
|
@ -3230,7 +3230,7 @@ Status ValidateScatterDimensionNumbers(
|
||||
/*inputs=*/1));
|
||||
|
||||
std::vector<int64> expanded_scatter_indices_shape =
|
||||
ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
|
||||
SpanToVector(scatter_indices_shape.dimensions());
|
||||
if (expanded_scatter_indices_shape.size() ==
|
||||
scatter_dim_numbers.index_vector_dim()) {
|
||||
expanded_scatter_indices_shape.push_back(1);
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
@ -72,7 +73,7 @@ class Shape {
|
||||
dynamic_dimensions_[dimension] = is_dynamic;
|
||||
}
|
||||
|
||||
const std::vector<bool>& dynamic_dimensions() const {
|
||||
absl::Span<const bool> dynamic_dimensions() const {
|
||||
return dynamic_dimensions_;
|
||||
}
|
||||
|
||||
@ -104,7 +105,7 @@ class Shape {
|
||||
dimensions_.clear();
|
||||
dynamic_dimensions_.clear();
|
||||
}
|
||||
const std::vector<int64>& dimensions() const { return dimensions_; }
|
||||
absl::Span<const int64> dimensions() const { return dimensions_; }
|
||||
absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
|
||||
|
||||
// Methods for accessing the tuple subshapes. This field only non-empty for
|
||||
@ -219,11 +220,11 @@ class Shape {
|
||||
// The array bounds of the dimensions. This is nonempty only for array
|
||||
// shapes. For a dynamically-sized dimension, the respective value in this
|
||||
// vector is an inclusive upper limit of the array bound.
|
||||
std::vector<int64> dimensions_;
|
||||
absl::InlinedVector<int64, 6> dimensions_;
|
||||
|
||||
// This vector is the same size as 'dimensions_' and indicates whether the
|
||||
// respective dimension is dynamically sized.
|
||||
std::vector<bool> dynamic_dimensions_;
|
||||
absl::InlinedVector<bool, 6> dynamic_dimensions_;
|
||||
|
||||
// The tuple element subshapes. This is nonempty only for tuple shapes.
|
||||
std::vector<Shape> tuple_shapes_;
|
||||
|
@ -227,7 +227,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
|
||||
// Since the physical layout is kept the same, the tiles and element size are
|
||||
// the same also.
|
||||
*new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles();
|
||||
new_shape.mutable_layout()->mutable_tiles()->assign(
|
||||
shape.layout().tiles().begin(), shape.layout().tiles().end());
|
||||
new_shape.mutable_layout()->set_element_size_in_bits(
|
||||
shape.layout().element_size_in_bits());
|
||||
return new_shape;
|
||||
@ -1309,7 +1310,8 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
auto layout = simple_output_shape->layout().minor_to_major();
|
||||
std::vector<int64> layout =
|
||||
SpanToVector(simple_output_shape->layout().minor_to_major());
|
||||
// For each one sized dimension in the output, increment the dimension
|
||||
// numbers in layout that are more minor than the one.
|
||||
absl::InlinedVector<int64, 8> dim_map;
|
||||
|
@ -86,9 +86,9 @@ using DimensionVector = absl::InlinedVector<int64, kInlineRank>;
|
||||
XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)
|
||||
|
||||
// Helper for macros above. Don't use directly.
|
||||
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \
|
||||
static ::xla::TimerStats XLA_TimerStats##counter; \
|
||||
::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
|
||||
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \
|
||||
static ::xla::TimerStats XLA_TimerStats##counter; \
|
||||
::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
|
||||
label, /*enabled=*/VLOG_IS_ON(level), &XLA_TimerStats##counter);
|
||||
|
||||
struct TimerStats {
|
||||
@ -507,7 +507,7 @@ void EraseAt(C* c, int64 index) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> ArraySliceToVector(absl::Span<const T> slice) {
|
||||
std::vector<T> SpanToVector(absl::Span<const T> slice) {
|
||||
return std::vector<T>(slice.begin(), slice.end());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user