Correct SpaceToBatch result shape inference
The TF-to-TF lowering of SpaceToBatchND did not infer all intermediate result types. This change updates the pass to infer them correctly.

PiperOrigin-RevId: 341503610
Change-Id: Iac5e909556a8bf96886fa28c50f1af7da46c333c
parent b3d45cd17c
commit af3f3f9111
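To make the effect concrete, here is a minimal sketch of the lowered IR for the constant-paddings test case updated in this change (input tensor<1x8x2xf32>, block shape 3, paddings [[3, 4]]). The SSA names are illustrative; the types follow the new CHECK lines below. Previously the intermediate tf.PadV2 result was typed tensor<1x5x2xf32>; with the fix it and the following tf.Reshape/tf.Transpose results carry fully inferred static shapes:

  %padded   = "tf.PadV2"(%arg0, %paddings, %pad_value)
                : (tensor<1x8x2xf32>, tensor<3x2xi64>, tensor<f32>) -> tensor<1x15x2xf32>
  %reshaped = "tf.Reshape"(%padded, %shape_1532)
                : (tensor<1x15x2xf32>, tensor<4xi64>) -> tensor<1x5x3x2xf32>
  %permuted = "tf.Transpose"(%reshaped, %perm_2013)
                : (tensor<1x5x3x2xf32>, tensor<4xi64>) -> tensor<3x1x5x2xf32>
  %result   = "tf.Reshape"(%permuted, %shape_352)
                : (tensor<3x1x5x2xf32>, tensor<3xi64>) -> tensor<3x5x2xf32>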
@@ -231,11 +231,11 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
   // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
   // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
   // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
-  // CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor<i64>}
   // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
   // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
-  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]])
+  // CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
+  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
   // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
   // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
   // CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
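In the hunk above, the per-dimension padding totals are now computed with tf.Unpack along axis 1 followed by tf.Add rather than a tf.Sum reduction. A minimal sketch of the resulting ops, assuming the rank-4 input of this test so that full_paddings is tensor<4x2xi64> (value names illustrative):

  %paddings:2   = "tf.Unpack"(%full_paddings) {axis = 1 : i64}
                    : (tensor<4x2xi64>) -> (tensor<4xi64>, tensor<4xi64>)
  %paddings_sum = "tf.Add"(%paddings#0, %paddings#1)
                    : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>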
@@ -256,14 +256,25 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
 }

 // Verify the result shape for the tf.PadV2 op.
+// CHECK-LABEL: const_paddings_space_to_batch_nd
 func @const_paddings_space_to_batch_nd(%arg0: tensor<1x8x2xf32>) -> (tensor<3x5x2xf32>) {
   %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
   %1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>

-  // CHECK: "tf.PadV2"
-  // CHECK-SAME: tensor<1x5x2xf32>
+  // CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>}
+  // CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>}
+  // CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}
+  // CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>}
+  // CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]])
+  // CHECK-SAME: tensor<1x15x2xf32>
+  // CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
+  // CHECK-DAG: [[VAL7:%.+]] = "tf.Transpose"([[VAL6]], [[VAL4]])
+  // CHECK-DAG: [[VAL8:%.+]] = "tf.Reshape"([[VAL7]], [[VAL0]])
   %2 = "tf.SpaceToBatchND"(%arg0, %0, %1) : (tensor<1x8x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<3x5x2xf32>

+  // CHECK: return [[VAL8]]
   return %2 : tensor<3x5x2xf32>
 }

@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -805,8 +806,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
                 ConcatV2Op::getOperationName(),
                 AddOp::getOperationName(),
                 PadOp::getOperationName(),
-                SumOp::getOperationName(),
                 SplitOp::getOperationName(),
+                UnpackOp::getOperationName(),
                 DivOp::getOperationName(),
                 MulOp::getOperationName(),
                 ReshapeOp::getOperationName(),
@@ -867,6 +868,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
     // full_paddings won't be available as a constant for shape inference.
     ElementsAttr block_shape;
     ElementsAttr paddings;
+    llvm::SmallVector<int64_t, 4> block_shape_ints;
     auto padded_shape = llvm::to_vector<4>(input_shape);
     if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
         matchPattern(op.paddings(), m_Constant(&paddings))) {
@@ -876,13 +878,14 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
             paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
         int64_t block_shape_i =
             block_shape.getValue({i}).cast<IntegerAttr>().getInt();
-        padded_shape[i + 1] =
-            (paddings_sum + padded_shape[i + 1]) / block_shape_i;
+        padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
+        block_shape_ints.push_back(block_shape_i);
       }
     } else {
       for (int i = 0; i < block_rank; i++) {
         padded_shape[i + 1] = ShapedType::kDynamicSize;
       }
+      block_shape_ints.resize(block_shape_type.getNumElements(), -1);
     }

     auto padded_type =
@@ -893,13 +896,13 @@ class LowerSpaceToBatchNDOp : public RewritePattern {

     auto paddings_sum_type =
         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
-    auto one_i64 = rewriter.create<ConstOp>(
-        loc, GetScalarOfType(rewriter.getIntegerType(64), 1));
     // paddings_sum = paddings[*,0] + paddings[*,1]
-    auto paddings_sum =
-        rewriter.create<SumOp>(loc, paddings_sum_type, full_paddings, one_i64);
+    auto paddings_split = rewriter.create<UnpackOp>(
+        loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
+        rewriter.getI64IntegerAttr(1));
+    auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
+                                               paddings_split.getResult(1));

-    // input_shape_tensor = input.shape
     auto input_shape_tensor = rewriter.create<ConstOp>(
         loc,
         DenseElementsAttr::get(
@@ -928,25 +931,46 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
                               block_shape_i64)
             .output());

+    SmallVector<int64_t, 4> outer_shape_ints;
     SmallVector<Value, 4> outer_shape_vals;
     for (int64_t i = 0; i < block_rank; ++i) {
       // TODO(b/157475606): Insert tf.Assert that the following division has
       // remainder 0.
       outer_shape_vals.push_back(rewriter.create<DivOp>(
           loc, padded_shape_splits[1 + i], block_shape_splits[i]));
+
+      auto padded_shape_i = padded_shape[1 + i];
+      auto block_shape_ints_i = block_shape_ints[i];
+
+      // Compute the outer_shape constant values to infer the reshape.
+      if (padded_shape_i == -1 || block_shape_ints_i == -1) {
+        outer_shape_ints.push_back(-1);
+      } else {
+        outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
+      }
     }

     SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
+    SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
     for (int64_t i = 0; i < block_rank; ++i) {
       reshaped_shape_vals.push_back(outer_shape_vals[i]);
       reshaped_shape_vals.push_back(block_shape_splits[i]);
+
+      reshaped_shape_ints.push_back(outer_shape_ints[i]);
+      reshaped_shape_ints.push_back(block_shape_ints[i]);
     }
     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
       reshaped_shape_vals.push_back(padded_shape_splits[i]);
+      reshaped_shape_ints.push_back(padded_shape[i]);
     }
     auto reshaped_shape = ValuesToRank1(
         rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);

+    auto reshaped = rewriter.create<ReshapeOp>(
+        loc,
+        RankedTensorType::get(reshaped_shape_ints, input_type.getElementType()),
+        padded, reshaped_shape);
+
     SmallVector<int64_t, 6> permutation_vals;
     for (int64_t i = 0; i < block_rank; ++i) {
       permutation_vals.push_back(2 + 2 * i);
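Worked through with the constant-paddings test above (input tensor<1x8x2xf32>, block shape 3, paddings [3, 4]): padded_shape[1] is 8 + 3 + 4 = 15, outer_shape_ints[0] is 15 / 3 = 5, and reshaped_shape_ints becomes [1, 5, 3, 2], so the tf.Reshape created here can be given a fully ranked result type instead of relying on the builder's inference.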
@@ -961,6 +985,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
     auto permutation = rewriter.create<ConstOp>(
         loc, GetI64ElementsAttr(permutation_vals, &rewriter));

+    auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
     auto output_batch = padded_shape_splits[0];
     for (int64_t i = 0; i < block_rank; ++i) {
       output_batch =
@@ -975,8 +1000,6 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
     }
     auto output_shape = ValuesToRank1(
         rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
-    auto reshaped = rewriter.create<ReshapeOp>(loc, padded, reshaped_shape);
-    auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);

     // Sometimes the result type is more specific than what the reshape builder
     // can infer.