From 401387665d3fb80ddcb0fb216d780ac04d86933f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 11 Jan 2017 20:09:44 -0800 Subject: [PATCH] [XLA] Use `ShapeUtil::PermuteDimensions` in `ShapeInference::InferTransposeShape` Change: 144282152 --- tensorflow/compiler/xla/service/shape_inference.cc | 4 +--- tensorflow/compiler/xla/service/shape_inference_test.cc | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 11559ad7578..fbab2dfd4af 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1319,9 +1319,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However, // we need output[i]=input[dimensions[i]] which is // Permute(Inverse(dimensions),input). - return ShapeUtil::MakeShape(operand.element_type(), - Permute(InversePermutation(dimensions), - AsInt64Slice(operand.dimensions()))); + return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand); } /* static */ StatusOr ShapeInference::InferSelectShape( diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 10fd4e53c5c..5a1ae6b0024 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1125,8 +1125,8 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0}); EXPECT_IS_OK(inferred_shape_and_status); Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); - EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, - ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); + EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape, + ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } } // namespace