diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index eba06465b50..bed926c416d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -68,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -1332,10 +1333,10 @@ LogicalResult SpaceToBatchNDOp::inferReturnTypes( // The rest of the dimension sizes can be calculated when block_shape and // paddings arguments are constant. - ElementsAttr block_shape_attr; - ElementsAttr paddings_attr; - if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) && - matchPattern(paddings_val, m_Constant(&paddings_attr))) { + DenseIntElementsAttr block_shape_attr; + DenseIntElementsAttr paddings_attr; + if (GetValueAsConstant(block_shape_val, block_shape_attr) && + GetValueAsConstant(paddings_val, paddings_attr)) { int64_t return_batch = input_shape[0]; for (uint64_t i = 0; i < block_rank; ++i) { // Propagate dynamic dimension. 
@@ -1347,10 +1348,10 @@ LogicalResult SpaceToBatchNDOp::inferReturnTypes( continue; } int64_t paddings_sum = - paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt() + - paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt(); + paddings_attr.getValue<APInt>({i, 0}).getSExtValue() + + paddings_attr.getValue<APInt>({i, 1}).getSExtValue(); int64_t block_shape_i = - block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt(); + block_shape_attr.getValue<APInt>({i}).getSExtValue(); return_batch *= block_shape_i; return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i; } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 50a90b7b957..98fe02e5dea 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1136,4 +1136,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -} + + // CHECK-LABEL: check_walking_identity + func @check_walking_identity(%arg0 : tensor<1x192x256x128xf32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<2x2xi32>) -> tensor<2x2xi32> + // CHECK: SpaceToBatchND{{.*}}-> tensor<4x98x130x128xf32> + %3 = "tf.SpaceToBatchND"(%arg0, %0, %2) {device = ""} : (tensor<1x192x256x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?x128xf32> + return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index bd81cae5730..e0280698a97 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ 
b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -47,6 +47,23 @@ inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) { }); } +// Forward declare these passthrough ops. +// TODO(jpienaar): Remove these and use trait instead. +class IdentityOp; +class IdentityNOp; + +// Returns if a value corresponds to a constant, returns the matched constant +// as an attribute. +template <typename AttrT> +bool GetValueAsConstant(Value val, AttrT &attr) { + while (auto result = val.dyn_cast<OpResult>()) { + Operation *op = result.getOwner(); + if (!isa<IdentityOp>(op) && !isa<IdentityNOp>(op)) break; + val = op->getOperand(result.getResultNumber()); + } + return matchPattern(val, m_Constant(&attr)); +} + } // namespace TF } // namespace mlir