Correct SpaceToBatch result shape inference

The TF-to-TF lowering of SpaceToBatchND did not infer the result types
of all the ops it creates. Updated the pass to correct this.

PiperOrigin-RevId: 341503610
Change-Id: Iac5e909556a8bf96886fa28c50f1af7da46c333c
This commit is contained in:
Robert Suderman 2020-11-09 15:57:42 -08:00 committed by TensorFlower Gardener
parent b3d45cd17c
commit af3f3f9111
2 changed files with 48 additions and 14 deletions

View File

@ -231,11 +231,11 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
// CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
// CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor<i64>}
// CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]])
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
@ -256,14 +256,25 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
}
// Verify the result shape for the tf.PadV2 op.
// CHECK-LABEL: const_paddings_space_to_batch_nd
// Lowering test where both the block_shape and paddings operands of
// tf.SpaceToBatchND are constants. The expected output is the sequence
// tf.PadV2 -> tf.Reshape -> tf.Transpose -> tf.Reshape, with every shape and
// the permutation materialized as tf.Const ops.
func @const_paddings_space_to_batch_nd(%arg0: tensor<1x8x2xf32>) -> (tensor<3x5x2xf32>) {
// block_shape operand: [3].
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
// paddings operand: a single pair [3, 4].
%1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
// NOTE(review): the next two directives expect tensor<1x5x2xf32> on tf.PadV2,
// which conflicts with the tensor<1x15x2xf32> expectation further down. They
// look like pre-change lines still present in this view -- confirm which set
// is live before relying on them.
// CHECK: "tf.PadV2"
// CHECK-SAME: tensor<1x5x2xf32>
// VAL0 is the shape fed to the final tf.Reshape; VAL1 is the shape fed to the
// first tf.Reshape.
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>}
// VAL2 is the paddings tensor passed to tf.PadV2: the [3, 4] pair surrounded
// by zero rows for the first and last dimensions. VAL3 is the pad value.
// CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}
// CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// VAL4 is the permutation passed to tf.Transpose.
// CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>}
// The padded middle dimension is 8 + 3 + 4 = 15.
// CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]])
// CHECK-SAME: tensor<1x15x2xf32>
// CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
// CHECK-DAG: [[VAL7:%.+]] = "tf.Transpose"([[VAL6]], [[VAL4]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.Reshape"([[VAL7]], [[VAL0]])
%2 = "tf.SpaceToBatchND"(%arg0, %0, %1) : (tensor<1x8x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<3x5x2xf32>
// The function must return the result of the final tf.Reshape.
// CHECK: return [[VAL8]]
return %2 : tensor<3x5x2xf32>
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/tensor_format.h"
@ -805,8 +806,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
ConcatV2Op::getOperationName(),
AddOp::getOperationName(),
PadOp::getOperationName(),
SumOp::getOperationName(),
SplitOp::getOperationName(),
UnpackOp::getOperationName(),
DivOp::getOperationName(),
MulOp::getOperationName(),
ReshapeOp::getOperationName(),
@ -867,6 +868,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
// full_paddings won't be available as a constant for shape inference.
ElementsAttr block_shape;
ElementsAttr paddings;
llvm::SmallVector<int64_t, 4> block_shape_ints;
auto padded_shape = llvm::to_vector<4>(input_shape);
if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
matchPattern(op.paddings(), m_Constant(&paddings))) {
@ -876,13 +878,14 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
int64_t block_shape_i =
block_shape.getValue({i}).cast<IntegerAttr>().getInt();
padded_shape[i + 1] =
(paddings_sum + padded_shape[i + 1]) / block_shape_i;
padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
block_shape_ints.push_back(block_shape_i);
}
} else {
for (int i = 0; i < block_rank; i++) {
padded_shape[i + 1] = ShapedType::kDynamicSize;
}
block_shape_ints.resize(block_shape_type.getNumElements(), -1);
}
auto padded_type =
@ -893,13 +896,13 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
auto paddings_sum_type =
RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
auto one_i64 = rewriter.create<ConstOp>(
loc, GetScalarOfType(rewriter.getIntegerType(64), 1));
// paddings_sum = paddings[*,0] + paddings[*,1]
auto paddings_sum =
rewriter.create<SumOp>(loc, paddings_sum_type, full_paddings, one_i64);
auto paddings_split = rewriter.create<UnpackOp>(
loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
rewriter.getI64IntegerAttr(1));
auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
paddings_split.getResult(1));
// input_shape_tensor = input.shape
auto input_shape_tensor = rewriter.create<ConstOp>(
loc,
DenseElementsAttr::get(
@ -928,25 +931,46 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
block_shape_i64)
.output());
SmallVector<int64_t, 4> outer_shape_ints;
SmallVector<Value, 4> outer_shape_vals;
for (int64_t i = 0; i < block_rank; ++i) {
// TODO(b/157475606): Insert tf.Assert that the following division has
// remainder 0.
outer_shape_vals.push_back(rewriter.create<DivOp>(
loc, padded_shape_splits[1 + i], block_shape_splits[i]));
auto padded_shape_i = padded_shape[1 + i];
auto block_shape_ints_i = block_shape_ints[i];
// Compute the outer_shape constant values to infer the reshape.
if (padded_shape_i == -1 || block_shape_ints_i == -1) {
outer_shape_ints.push_back(-1);
} else {
outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
}
}
SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
for (int64_t i = 0; i < block_rank; ++i) {
reshaped_shape_vals.push_back(outer_shape_vals[i]);
reshaped_shape_vals.push_back(block_shape_splits[i]);
reshaped_shape_ints.push_back(outer_shape_ints[i]);
reshaped_shape_ints.push_back(block_shape_ints[i]);
}
for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
reshaped_shape_vals.push_back(padded_shape_splits[i]);
reshaped_shape_ints.push_back(padded_shape[i]);
}
auto reshaped_shape = ValuesToRank1(
rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
auto reshaped = rewriter.create<ReshapeOp>(
loc,
RankedTensorType::get(reshaped_shape_ints, input_type.getElementType()),
padded, reshaped_shape);
SmallVector<int64_t, 6> permutation_vals;
for (int64_t i = 0; i < block_rank; ++i) {
permutation_vals.push_back(2 + 2 * i);
@ -961,6 +985,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
auto permutation = rewriter.create<ConstOp>(
loc, GetI64ElementsAttr(permutation_vals, &rewriter));
auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
auto output_batch = padded_shape_splits[0];
for (int64_t i = 0; i < block_rank; ++i) {
output_batch =
@ -975,8 +1000,6 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
}
auto output_shape = ValuesToRank1(
rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
auto reshaped = rewriter.create<ReshapeOp>(loc, padded, reshaped_shape);
auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
// Sometimes the result type is more specific than what the reshape builder
// can infer.