[XLA] Make PositionInContainer polymorphic in container type
Change: 144283085
parent 86ab87491b
commit a1e2fb98f6
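
The change replaces the ArraySlice<int64>-only overload of PositionInContainer with a function template defined in the header, so call sites can pass the repeated field returned by Layout::minor_to_major() directly instead of wrapping it in AsInt64Slice(). Below is a minimal standalone sketch of the new helper, not part of the commit, with std::vector and int64_t as stand-ins for the XLA container and int64 types:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Sketch of the templated helper: any container exposing begin()/end() over
// integer values works, so no AsInt64Slice() conversion is needed.
template <typename Container>
int64_t PositionInContainer(const Container& container, int64_t value) {
  return std::distance(container.begin(),
                       std::find(container.begin(), container.end(), value));
}

int main() {
  // A std::vector stands in for Layout::minor_to_major(); the call sites in
  // the hunks below pass the repeated field itself.
  std::vector<int64_t> minor_to_major = {3, 2, 1, 0};
  assert(PositionInContainer(minor_to_major, 1) == 2);  // value 1 sits at index 2
  assert(PositionInContainer(minor_to_major, 7) == 4);  // absent value maps to size()
  return 0;
}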
@@ -782,9 +782,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
       input_shape.layout().minor_to_major(0) != dnums.feature_dimension() ||
       // The input feature dimension should come later in the minor-to-major
       // order.
-      (PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+      (PositionInContainer(filter_shape.layout().minor_to_major(),
                            dnums.kernel_input_feature_dimension()) <
-       PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+       PositionInContainer(filter_shape.layout().minor_to_major(),
                            dnums.kernel_output_feature_dimension()))) {
     return Status::OK();
   }

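In the hunk above, PositionInContainer is applied to the filter layout's minor_to_major order, so its result is a dimension's physical position (0 = most minor). The simplifier returns Status::OK() without rewriting when the kernel input feature dimension is laid out more minor than the kernel output feature dimension. A rough sketch of that comparison, reusing the PositionInContainer stand-in from the sketch above, with a plain vector and plain ints in place of the real Shape and ConvolutionDimensionNumbers accessors:

// Hypothetical helper, not in the commit: true when the kernel input feature
// dimension is physically more minor than the kernel output feature dimension.
bool KernelInputFeatureIsMoreMinor(const std::vector<int64_t>& filter_minor_to_major,
                                   int64_t kernel_input_feature_dim,
                                   int64_t kernel_output_feature_dim) {
  return PositionInContainer(filter_minor_to_major, kernel_input_feature_dim) <
         PositionInContainer(filter_minor_to_major, kernel_output_feature_dim);
}
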
@@ -1108,18 +1108,15 @@ Status IrEmitterUnnested::EmitReductionToVector(
   int64 width = 1;
   for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
        ++input_dim) {
-    if (PositionInContainer(
-            AsInt64Slice(input_shape.layout().minor_to_major()), input_dim) >
-        PositionInContainer(
-            AsInt64Slice(input_shape.layout().minor_to_major()),
-            input_dim_to_keep)) {
+    if (PositionInContainer(input_shape.layout().minor_to_major(),
+                            input_dim) >
+        PositionInContainer(input_shape.layout().minor_to_major(),
+                            input_dim_to_keep)) {
       depth *= input_shape.dimensions(input_dim);
-    } else if (PositionInContainer(
-                   AsInt64Slice(input_shape.layout().minor_to_major()),
-                   input_dim) <
-               PositionInContainer(
-                   AsInt64Slice(input_shape.layout().minor_to_major()),
-                   input_dim_to_keep)) {
+    } else if (PositionInContainer(input_shape.layout().minor_to_major(),
+                                   input_dim) <
+               PositionInContainer(input_shape.layout().minor_to_major(),
+                                   input_dim_to_keep)) {
       width *= input_shape.dimensions(input_dim);
     }
   }

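The reduction emitter uses the same position lookup to classify each input dimension relative to the dimension being kept: dimensions laid out more major than input_dim_to_keep multiply into depth, and more minor ones into width. A compact sketch of that loop, again reusing the PositionInContainer stand-in from above, with plain vectors as hypothetical stand-ins for the input shape's layout order and dimension sizes:

// Requires <utility> and <vector> in addition to the includes above.
// Returns {depth, width} as computed by the loop in the hunk above.
std::pair<int64_t, int64_t> DepthAndWidth(const std::vector<int64_t>& minor_to_major,
                                          const std::vector<int64_t>& dims,
                                          int64_t input_dim_to_keep) {
  int64_t depth = 1;
  int64_t width = 1;
  for (int64_t input_dim = 0; input_dim < static_cast<int64_t>(dims.size());
       ++input_dim) {
    if (PositionInContainer(minor_to_major, input_dim) >
        PositionInContainer(minor_to_major, input_dim_to_keep)) {
      depth *= dims[input_dim];  // more major than the kept dimension
    } else if (PositionInContainer(minor_to_major, input_dim) <
               PositionInContainer(minor_to_major, input_dim_to_keep)) {
      width *= dims[input_dim];  // more minor than the kept dimension
    }
  }
  return {depth, width};
}
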
@@ -176,12 +176,6 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
   return output;
 }
 
-int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
-                          int64 value) {
-  return std::find(container.begin(), container.end(), value) -
-         container.begin();
-}
-
 PaddingConfig MakeNoPaddingConfig(int64 rank) {
   PaddingConfig padding_config;
   for (int64 dnum = 0; dnum < rank; ++dnum) {

@@ -183,8 +183,11 @@ std::vector<int64> InversePermutation(
 std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
                                        tensorflow::gtl::ArraySlice<int64> p2);
 
-int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
-                          int64 value);
+template <typename Container>
+int64 PositionInContainer(const Container& container, int64 value) {
+  return std::distance(container.begin(),
+                       std::find(container.begin(), container.end(), value));
+}
 
 // Returns a PaddingConfig object that represents no padding for the given rank.
 PaddingConfig MakeNoPaddingConfig(int64 rank);

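Two small points about the new header definition: making PositionInContainer a template over the container type is what lets the call sites above pass the repeated field returned by minor_to_major() directly, and computing the index with std::distance rather than iterator subtraction keeps the helper usable with containers whose iterators are not random access.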