diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index ff480ef4980..2de6c3c16d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -231,11 +231,11 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>} // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor} // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor} - // CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor} // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]]) // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]]) - // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]]) + // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64} + // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1) // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>} // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]]) // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]]) @@ -256,14 +256,25 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens } // Verify the result shape for the tf.PadV2 op. +// CHECK-LABEL: const_paddings_space_to_batch_nd func @const_paddings_space_to_batch_nd(%arg0: tensor<1x8x2xf32>) -> (tensor<3x5x2xf32>) { %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32> - // CHECK: "tf.PadV2" - // CHECK-SAME: tensor<1x5x2xf32> + + // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>} + // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>} + // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>} + // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} + // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>} + // CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]]) + // CHECK-SAME: tensor<1x15x2xf32> + // CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]]) + // CHECK-DAG: [[VAL7:%.+]] = "tf.Transpose"([[VAL6]], [[VAL4]]) + // CHECK-DAG: [[VAL8:%.+]] = "tf.Reshape"([[VAL7]], [[VAL0]]) %2 = "tf.SpaceToBatchND"(%arg0, %0, %1) : (tensor<1x8x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<3x5x2xf32> + // CHECK: return [[VAL8]] return %2 : tensor<3x5x2xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 1b8f7d2f596..53a73ce89e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/util/tensor_format.h" @@ -805,8 +806,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern { ConcatV2Op::getOperationName(), AddOp::getOperationName(), PadOp::getOperationName(), - SumOp::getOperationName(), SplitOp::getOperationName(), + UnpackOp::getOperationName(), DivOp::getOperationName(), MulOp::getOperationName(), ReshapeOp::getOperationName(), @@ -867,6 +868,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { // full_paddings won't be available as a constant for shape inference. ElementsAttr block_shape; ElementsAttr paddings; + llvm::SmallVector block_shape_ints; auto padded_shape = llvm::to_vector<4>(input_shape); if (matchPattern(op.block_shape(), m_Constant(&block_shape)) && matchPattern(op.paddings(), m_Constant(&paddings))) { @@ -876,13 +878,14 @@ class LowerSpaceToBatchNDOp : public RewritePattern { paddings.getValue({i, 1}).cast().getInt(); int64_t block_shape_i = block_shape.getValue({i}).cast().getInt(); - padded_shape[i + 1] = - (paddings_sum + padded_shape[i + 1]) / block_shape_i; + padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]); + block_shape_ints.push_back(block_shape_i); } } else { for (int i = 0; i < block_rank; i++) { padded_shape[i + 1] = ShapedType::kDynamicSize; } + block_shape_ints.resize(block_shape_type.getNumElements(), -1); } auto padded_type = @@ -893,13 +896,13 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto paddings_sum_type = RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)); - auto one_i64 = rewriter.create( - loc, GetScalarOfType(rewriter.getIntegerType(64), 1)); // paddings_sum = paddings[*,0] + paddings[*,1] - auto paddings_sum = - rewriter.create(loc, paddings_sum_type, full_paddings, one_i64); + auto paddings_split = rewriter.create( + loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings, + rewriter.getI64IntegerAttr(1)); + auto paddings_sum = rewriter.create(loc, paddings_split.getResult(0), + paddings_split.getResult(1)); - // input_shape_tensor = input.shape auto input_shape_tensor = rewriter.create( loc, DenseElementsAttr::get( @@ -928,25 +931,46 @@ class LowerSpaceToBatchNDOp : public RewritePattern { block_shape_i64) .output()); + SmallVector outer_shape_ints; SmallVector outer_shape_vals; for (int64_t i = 0; i < block_rank; ++i) { // TODO(b/157475606): Insert tf.Assert that the following division has // remainder 0. outer_shape_vals.push_back(rewriter.create( loc, padded_shape_splits[1 + i], block_shape_splits[i])); + + auto padded_shape_i = padded_shape[1 + i]; + auto block_shape_ints_i = block_shape_ints[i]; + + // Compute the outer_shape constant values to infer the reshape. + if (padded_shape_i == -1 || block_shape_ints_i == -1) { + outer_shape_ints.push_back(-1); + } else { + outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i); + } } SmallVector reshaped_shape_vals{padded_shape_splits[0]}; + SmallVector reshaped_shape_ints{padded_shape[0]}; for (int64_t i = 0; i < block_rank; ++i) { reshaped_shape_vals.push_back(outer_shape_vals[i]); reshaped_shape_vals.push_back(block_shape_splits[i]); + + reshaped_shape_ints.push_back(outer_shape_ints[i]); + reshaped_shape_ints.push_back(block_shape_ints[i]); } for (int64_t i = 1 + block_rank; i < input_rank; ++i) { reshaped_shape_vals.push_back(padded_shape_splits[i]); + reshaped_shape_ints.push_back(padded_shape[i]); } auto reshaped_shape = ValuesToRank1( rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals); + auto reshaped = rewriter.create( + loc, + RankedTensorType::get(reshaped_shape_ints, input_type.getElementType()), + padded, reshaped_shape); + SmallVector permutation_vals; for (int64_t i = 0; i < block_rank; ++i) { permutation_vals.push_back(2 + 2 * i); @@ -961,6 +985,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto permutation = rewriter.create( loc, GetI64ElementsAttr(permutation_vals, &rewriter)); + auto permuted = rewriter.create(loc, reshaped, permutation); auto output_batch = padded_shape_splits[0]; for (int64_t i = 0; i < block_rank; ++i) { output_batch = @@ -975,8 +1000,6 @@ class LowerSpaceToBatchNDOp : public RewritePattern { } auto output_shape = ValuesToRank1( rewriter, loc, rewriter.getIntegerType(64), output_shape_vals); - auto reshaped = rewriter.create(loc, padded, reshaped_shape); - auto permuted = rewriter.create(loc, reshaped, permutation); // Sometimes the result type is more specific than what the reshape builder // can infer.