Walk identity nodes when matching constants
This is a workaround that does matching by skipping over identity nodes locally. This needs to be replaced by a more general approach. PiperOrigin-RevId: 357785926 Change-Id: I27cb65bccd641bebbb468d9515272b9054e4e29b
This commit is contained in:
parent
bb0478bdfc
commit
11baf03f4e
@ -68,6 +68,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
@ -1332,10 +1333,10 @@ LogicalResult SpaceToBatchNDOp::inferReturnTypes(
|
||||
|
||||
// The rest of the dimension sizes can be calculated when block_shape and
|
||||
// paddings arguments are constant.
|
||||
ElementsAttr block_shape_attr;
|
||||
ElementsAttr paddings_attr;
|
||||
if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) &&
|
||||
matchPattern(paddings_val, m_Constant(&paddings_attr))) {
|
||||
DenseIntElementsAttr block_shape_attr;
|
||||
DenseIntElementsAttr paddings_attr;
|
||||
if (GetValueAsConstant(block_shape_val, block_shape_attr) &&
|
||||
GetValueAsConstant(paddings_val, paddings_attr)) {
|
||||
int64_t return_batch = input_shape[0];
|
||||
for (uint64_t i = 0; i < block_rank; ++i) {
|
||||
// Propagate dynamic dimension.
|
||||
@ -1347,10 +1348,10 @@ LogicalResult SpaceToBatchNDOp::inferReturnTypes(
|
||||
continue;
|
||||
}
|
||||
int64_t paddings_sum =
|
||||
paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt() +
|
||||
paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
|
||||
paddings_attr.getValue<APInt>({i, 0}).getSExtValue() +
|
||||
paddings_attr.getValue<APInt>({i, 1}).getSExtValue();
|
||||
int64_t block_shape_i =
|
||||
block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
|
||||
block_shape_attr.getValue<APInt>({i}).getSExtValue();
|
||||
return_batch *= block_shape_i;
|
||||
return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i;
|
||||
}
|
||||
|
||||
@ -1136,4 +1136,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<?x?x?x3xf32>, tensor<3x3x3x16xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: check_walking_identity
|
||||
func @check_walking_identity(%arg0 : tensor<1x192x256x128xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%1 = "tf.Const"() {value = dense<2> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
%2 = "tf.Identity"(%1) {device = ""} : (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// CHECK: SpaceToBatchND{{.*}}-> tensor<4x98x130x128xf32>
|
||||
%3 = "tf.SpaceToBatchND"(%arg0, %0, %2) {device = ""} : (tensor<1x192x256x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?x128xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -47,6 +47,23 @@ inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) {
|
||||
});
|
||||
}
|
||||
|
||||
// Forward declare these passthrough ops.
|
||||
// TODO(jpienaar): Remove these and use trait instead.
|
||||
class IdentityOp;
|
||||
class IdentityNOp;
|
||||
|
||||
// Returns if a value corresponds to a constant, returns the matched constant
|
||||
// as an attribute.
|
||||
template <typename AttrT>
|
||||
bool GetValueAsConstant(Value val, AttrT &attr) {
|
||||
while (auto result = val.dyn_cast<OpResult>()) {
|
||||
Operation *op = result.getOwner();
|
||||
if (!isa<IdentityOp>(op) && !isa<IdentityNOp>(op)) break;
|
||||
val = op->getOperand(result.getResultNumber());
|
||||
}
|
||||
return matchPattern(val, m_Constant(&attr));
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user