replace ConstantOp by TF::ConstOp in tf.Slice arguments, because TF::ConstOp is expected by legalize_tf_patterns.td

PiperOrigin-RevId: 298485749
Change-Id: Id25ab83c71177b4af79e2837e7f19d2a12237899
This commit is contained in:
A. Unique TensorFlower 2020-03-02 17:17:53 -08:00 committed by TensorFlower Gardener
parent 188c063dfa
commit f291d9cefe
2 changed files with 49 additions and 49 deletions

View File

@ -5,19 +5,19 @@ func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -67,16 +67,16 @@ func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>)
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2FlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -122,19 +122,19 @@ func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>)
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -184,16 +184,16 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulFlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>

View File

@ -71,7 +71,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
Type resultType = RankedTensorType::get(shape, element_type);
auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
auto shape_tensor =
rewriter.create<ConstantOp>(loc, shape_spec_type, constant_attr);
rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
/*shape=*/shape_tensor);
}
@ -104,8 +104,8 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
auto begin_attr =
DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
auto begin = rewriter.create<ConstantOp>(loc, vector3_type, begin_attr);
auto size = rewriter.create<ConstantOp>(loc, vector3_type, size_attr);
auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
/*input=*/reshape_op.output(),
begin, size);