[XLA] Make PositionInContainer polymorphic in container type

Change: 144283085
This commit is contained in:
A. Unique TensorFlower 2017-01-11 20:28:09 -08:00 committed by TensorFlower Gardener
parent 86ab87491b
commit a1e2fb98f6
4 changed files with 15 additions and 21 deletions

View File

@ -782,9 +782,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
input_shape.layout().minor_to_major(0) != dnums.feature_dimension() ||
// The input feature dimension should come later in the minor-to-major
// order.
(PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
(PositionInContainer(filter_shape.layout().minor_to_major(),
dnums.kernel_input_feature_dimension()) <
PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
PositionInContainer(filter_shape.layout().minor_to_major(),
dnums.kernel_output_feature_dimension()))) {
return Status::OK();
}

View File

@ -1108,18 +1108,15 @@ Status IrEmitterUnnested::EmitReductionToVector(
int64 width = 1;
for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
++input_dim) {
if (PositionInContainer(
AsInt64Slice(input_shape.layout().minor_to_major()), input_dim) >
PositionInContainer(
AsInt64Slice(input_shape.layout().minor_to_major()),
input_dim_to_keep)) {
if (PositionInContainer(input_shape.layout().minor_to_major(),
input_dim) >
PositionInContainer(input_shape.layout().minor_to_major(),
input_dim_to_keep)) {
depth *= input_shape.dimensions(input_dim);
} else if (PositionInContainer(
AsInt64Slice(input_shape.layout().minor_to_major()),
input_dim) <
PositionInContainer(
AsInt64Slice(input_shape.layout().minor_to_major()),
input_dim_to_keep)) {
} else if (PositionInContainer(input_shape.layout().minor_to_major(),
input_dim) <
PositionInContainer(input_shape.layout().minor_to_major(),
input_dim_to_keep)) {
width *= input_shape.dimensions(input_dim);
}
}

View File

@ -176,12 +176,6 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
return output;
}
int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
int64 value) {
return std::find(container.begin(), container.end(), value) -
container.begin();
}
PaddingConfig MakeNoPaddingConfig(int64 rank) {
PaddingConfig padding_config;
for (int64 dnum = 0; dnum < rank; ++dnum) {

View File

@ -183,8 +183,11 @@ std::vector<int64> InversePermutation(
std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
tensorflow::gtl::ArraySlice<int64> p2);
int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
int64 value);
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
return std::distance(container.begin(),
std::find(container.begin(), container.end(), value));
}
// Returns a PaddingConfig object that represents no padding for the given rank.
PaddingConfig MakeNoPaddingConfig(int64 rank);