Fixed a bug where ConvertTFStridedSlice::RewriteNewAxisMask assumed the strided_slice input had RankedTensorType.

PiperOrigin-RevId: 344367093
Change-Id: I3e6fe3ddc9ed4dd74355014cd63e8eac2db39f9e
This commit is contained in:
A. Unique TensorFlower 2020-11-25 22:24:46 -08:00 committed by TensorFlower Gardener
parent 202ef04a97
commit dd35a3e925
2 changed files with 1 additions and 14 deletions

View File

@ -748,13 +748,4 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
// CHECK: "tf.DepthwiseConv2dNative"
}
// CHECK-LABEL: strided_slice_unranked_input
func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
%18 = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
%57 = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
%534 = "tf.StridedSlice"(%arg0, %57, %57, %18) {begin_mask = 11 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 11 : i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32>
return %534 : tensor<*xf32>
// CHECK: "tf.StridedSlice"
}
}

View File

@ -527,11 +527,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Insert a new reshape op.
Value original_input = strided_slice_op.input();
RankedTensorType original_input_type =
original_input.getType().dyn_cast<RankedTensorType>();
if (!original_input_type) {
return failure();
}
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
SmallVector<int64_t, 4> revised_shape;