Fixed a bug where ConvertTFStridedSlice::RewriteNewAxisMask assumed the strided_slice input had RankedTensorType.
PiperOrigin-RevId: 344367093 Change-Id: I3e6fe3ddc9ed4dd74355014cd63e8eac2db39f9e
This commit is contained in:
parent
202ef04a97
commit
dd35a3e925
@ -748,13 +748,4 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
|
||||
// CHECK: "tf.DepthwiseConv2dNative"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: strided_slice_unranked_input
|
||||
func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
%18 = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%57 = "tf.Const"() {value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%534 = "tf.StridedSlice"(%arg0, %57, %57, %18) {begin_mask = 11 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 11 : i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
return %534 : tensor<*xf32>
|
||||
// CHECK: "tf.StridedSlice"
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -527,11 +527,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
// Insert a new reshape op.
|
||||
Value original_input = strided_slice_op.input();
|
||||
RankedTensorType original_input_type =
|
||||
original_input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!original_input_type) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
original_input.getType().cast<RankedTensorType>();
|
||||
const ArrayRef<int64_t> &original_input_shape =
|
||||
original_input_type.getShape();
|
||||
SmallVector<int64_t, 4> revised_shape;
|
||||
|
Loading…
x
Reference in New Issue
Block a user