replace ConstantOp by TF::ConstOp in tf.Slice arguments, because TF::ConstOp is expected by legalize_tf_patterns.td
PiperOrigin-RevId: 298485749 Change-Id: Id25ab83c71177b4af79e2837e7f19d2a12237899
This commit is contained in:
parent
188c063dfa
commit
f291d9cefe
@ -5,19 +5,19 @@ func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2TwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -67,16 +67,16 @@ func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>)
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2FlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -122,19 +122,19 @@ func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>)
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulTwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -184,16 +184,16 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulFlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
|
@ -71,7 +71,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
|
||||
Type resultType = RankedTensorType::get(shape, element_type);
|
||||
auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
|
||||
auto shape_tensor =
|
||||
rewriter.create<ConstantOp>(loc, shape_spec_type, constant_attr);
|
||||
rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
|
||||
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
|
||||
/*shape=*/shape_tensor);
|
||||
}
|
||||
@ -104,8 +104,8 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
||||
auto begin_attr =
|
||||
DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
|
||||
auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
|
||||
auto begin = rewriter.create<ConstantOp>(loc, vector3_type, begin_attr);
|
||||
auto size = rewriter.create<ConstantOp>(loc, vector3_type, size_attr);
|
||||
auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
|
||||
auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
|
||||
auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
|
||||
/*input=*/reshape_op.output(),
|
||||
begin, size);
|
||||
|
Loading…
Reference in New Issue
Block a user