From e033fd5b33e5f3cfb7b075715e6d38c3de2383fd Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Thu, 14 May 2020 08:44:55 -0700
Subject: [PATCH] [TF MLIR SI] Don't constant fold, only consider result of
 folding

This results in fewer changes to the module during shape inference (e.g.,
only shapes are changed, no constant nodes are created). Effectively this
computes the folded result and then just uses that information locally, which
is conceptually more wasteful (a subsequent canonicalize pass may need to
recompute these results) but is less surprising and avoids dropping
attributes during this part. There are still additional changes needed to
avoid doing needless computations here; this change mostly focuses on
decreasing graph mutations. A small self-contained sketch illustrating this
caching approach is included after the diff.

PiperOrigin-RevId: 311539328
Change-Id: Ib6daa331c1e18a6d23463aa945c87e59d253708b
---
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     |  10 +-
 .../tensorflow/tests/shape_inference.mlir     |  31 ++--
 .../tensorflow/transforms/shape_inference.cc  | 145 +++++++++++++-----
 3 files changed, 137 insertions(+), 49 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 2007824369c..b21fef32cca 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -3551,12 +3551,20 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
   if (!const_perm) return {};
 
   auto const_value = const_perm.value();
-  const auto &elements = const_value.getValues<APInt>();
+  const auto elements = const_value.getValues<APInt>();
 
   for (auto it : llvm::enumerate(elements)) {
     if (it.index() != it.value()) return {};
   }
 
+  // TODO(jpienaar): Remove when we handle this more generally.
+  if (op.getType() != op.x().getType()) {
+    // If the types don't match then only fold if all the operands are in the
+    // TF dialect.
+    for (auto user : op.getOperation()->getUsers())
+      if (user->getDialect() != op.getDialect()) return {};
+  }
+
   return op.x();
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 160bba94cfc..cfe8db9025e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -3,8 +3,8 @@
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
 // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
 func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
-  // CHECK-NOT: tf.Cast
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"
+  // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK: return %[[RESULT]] : tensor<1xi32>
   %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
   %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
@@ -60,8 +60,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
 
 // CHECK-LABEL: func @simple_folding
 func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
-// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
-// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
+// CHECK: %[[SHAPE:.*]] = "tf.Shape"
+// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
 // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
 // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
 %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
@@ -300,13 +300,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     return %0 : tensor<*xi32>
   }
 
-  // CHECK-LABEL: func @fold_cast
-  func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
-    // CHECK-NOT: Cast
-    %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
-    return %0 : tensor<*xf32>
-  }
-
   // CHECK-LABEL: func @while_variant
   // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
   func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
@@ -362,8 +355,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
 
   // CHECK-LABEL: func @partitioned_call_func_const
   func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
-    // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
-    // CHECK: return %[[CONST]]
     return %arg0 : tensor<2xi32>
   }
 
@@ -410,4 +401,18 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
     return
   }
+
+  // CHECK-LABEL: const_fold
+  func @const_fold() -> () {
+    // CHECK: tf.Const
+    // CHECK-SAME: () -> tensor<4xi32>
+    %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
+    // CHECK: tf.Const
+    // CHECK-SAME: () -> tensor<4xi32>
+    %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
+    // CHECK: tf.Add
+    // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+    %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+    return
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 5a2cae38062..6a63e83be0f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -430,6 +430,7 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
 Attribute ComputeOutputComponent(const ValuePort& value_port,
                                  ValueQueryFn values) {
   LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
+  if (auto known = values(value_port)) return known;
 
   auto op = value_port.producer.dyn_cast<Operation*>();
   if (!op) return nullptr;
@@ -454,6 +455,7 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
     ValuePort op_port(op->getOperand(port[1]));
     return values(op_port);
   }
+
   return nullptr;
 }
 
@@ -475,8 +477,11 @@ class ShapeInference {
   }
 
   Attribute ComputeOutputComponent(const ValuePort& value_port) {
-    return ::mlir::TF::ComputeOutputComponent(
+    if (auto known_attr = results_[value_port]) return known_attr;
+    auto attr = ::mlir::TF::ComputeOutputComponent(
         value_port, [this](const ValuePort& port) { return results_[port]; });
+    RecordValue(value_port, attr);
+    return attr;
   }
 
   // Returns ShapeHandle if the op result could be computed as shape.
@@ -520,19 +525,35 @@ class ShapeInference {
   LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
                                                     int64_t max_iteration);
 
+  // Propagates any constant operand of call_op to the called function body's
+  // corresponding argument if the callee has only one use.
+  //
+  // TODO(b/154065712): Move this to a more general inter-procedural constant
+  // folding pass.
+  void PropagateConstantToCallee(CallOpInterface call_op,
+                                 SymbolRefAttr callee_sym, ModuleOp module);
+
+  // Propagates any constant return value of the callee function to the call
+  // op's corresponding result.
+  void PropagateConstantFromCallee(CallOpInterface call_op,
+                                   SymbolRefAttr callee_sym, ModuleOp module);
+
+  // Tries to compute the result of folding the op. This doesn't actually
+  // perform constant folding; it just computes the equivalent constants.
+  // Returns whether it was able to compute constant values.
+  LogicalResult TryToFold(Operation* op);
+
  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
   // e.g., first element of OpResult produced) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;
   int64_t graph_version_;
-  MLIRContext* context_;
   Dialect* tf_dialect_;
 };
 
 ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
     : graph_version_(graph_version) {
-  context_ = context;
   tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
 }
 
@@ -581,7 +602,6 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
     auto ret = ComputeOutputComponent(front);
     if (!ret) continue;
 
-    RecordValue(front, ret);
     LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
 
     // If worklist is empty, then this is the root query op.
@@ -686,10 +706,14 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
     size_t index = it.index();
 
     // If the operand is constant, then convert it to Tensor.
-    ElementsAttr attr;
-    if (matchPattern(operand, m_Constant(&attr))) {
+    ValuePort vp(operand);
+    Attribute attr = ComputeOutputComponent(vp);
+    if (!attr && matchPattern(operand, m_Constant(&attr)))
+      RecordValue(vp, attr);
+    if (attr) {
       tensorflow::Tensor* input_tensor = &tensors[index];
-      auto status = tensorflow::ConvertToTensor(attr, input_tensor);
+      auto status =
+          tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
       if (status.ok()) {
         input_tensors[index] = input_tensor;
       } else {
@@ -865,13 +889,9 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
   return success(all_succeeded);
 }
 
-// If the callee has only one use, propagates any constant operand of call_op to
-// the called function body's corresponding argument.
-//
-// TODO(b/154065712): Move this to a more general inter-procedural constant
-// folding pass.
-void PropagateConstantToCallee(CallOpInterface call_op,
-                               SymbolRefAttr callee_sym, ModuleOp module) {
+void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
+                                               SymbolRefAttr callee_sym,
+                                               ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   int num_uses = std::distance(func_uses->begin(), func_uses->end());
@@ -879,31 +899,29 @@ void PropagateConstantToCallee(CallOpInterface call_op,
   Operation* op = call_op.getOperation();
   if (num_uses == 1) {
     // If this is the only caller, and an operand is a constant, propagate
-    // the constant inside the function.
+    // the constant value inside the function.
     for (auto arg : func.getArguments()) {
-      auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
-      if (isa_and_nonnull<TF::ConstOp>(operand)) {
-        arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
-      }
+      auto operand = op->getOperand(arg.getArgNumber());
+      if (auto known_constant = ComputeOutputComponent(ValuePort(operand)))
+        RecordValue(ValuePort(arg), known_constant);
     }
   }
 }
 
-// Propagates any constant return value of the callee function to the call op's
-// corresponding result.
-void PropagateConstantFromCallee(CallOpInterface call_op,
-                                 SymbolRefAttr callee_sym, ModuleOp module) {
+void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
+                                                 SymbolRefAttr callee_sym,
+                                                 ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
-  // If the return value is a constant, replace the call result with a constant.
+  // If the return value is a constant, use the constant as the value of
+  // the call return.
   Operation* op = call_op.getOperation();
   OpBuilder builder(op);
   builder.setInsertionPointAfter(op);
   for (auto retval :
        llvm::enumerate(func.front().getTerminator()->getOperands())) {
-    auto retval_op = retval.value().getDefiningOp();
-    if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
-      op->getResult(retval.index())
-          .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
+    ValuePort vp(retval.value());
+    if (auto known_constant = ComputeOutputComponent(vp)) {
+      RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
     }
   }
 }
@@ -938,10 +956,68 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
   return success();
 }
 
+LogicalResult ShapeInference::TryToFold(Operation* op) {
+  // If any output result is known, then the op has probably been computed
+  // before.
+  if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
+    return success();
+
+  SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
+  SmallVector<OpFoldResult, 8> fold_results;
+
+  // Check to see if any operands to the operation are constant and whether
+  // the operation knows how to constant fold itself.
+  bool some_unknown = false;
+  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
+    if (!(constant_operands[i] =
+              ComputeOutputComponent(ValuePort(op->getOperand(i)))))
+      some_unknown = true;
+  }
+
+  // Attempt to constant fold the operation.
+  auto* abstract_op = op->getAbstractOperation();
+  if (abstract_op) {
+    if (failed(abstract_op->foldHook(op, constant_operands, fold_results)))
+      return failure();
+  } else {
+    Dialect* dialect = op->getDialect();
+    if (!dialect) return failure();
+    // Only attempt TF dialect fallback if there are no unknown operands.
+    if (some_unknown && dialect == tf_dialect_) return failure();
+    SmallVector<Attribute, 8> constants;
+    if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
+      return failure();
+    fold_results.assign(constants.begin(), constants.end());
+  }
+
+  for (auto result : zip(op->getResults(), fold_results)) {
+    auto fold_result = std::get<1>(result);
+    Attribute attr = nullptr;
+    if ((attr = fold_result.dyn_cast<Attribute>())) {
+      RecordValue(ValuePort(std::get<0>(result)), attr);
+    } else {
+      auto value = fold_result.get<Value>();
+      if ((attr = ComputeOutputComponent(ValuePort(value))))
+        RecordValue(ValuePort(std::get<0>(result)), attr);
+    }
+
+    if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
+      if (std::get<0>(result).getType() == eattr.getType()) continue;
+
+      // Inserts a cast back to the original type if any user is not in the
+      // TF dialect.
+      Type old_type = std::get<0>(result).getType();
+      std::get<0>(result).setType(eattr.getType());
+      AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
+                                         old_type);
+    }
+  }
+
+  return success();
+}
+
 LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
                                                       int64_t max_iteration) {
-  // An operation folder that is used to attempt folding before inference.
-  OperationFolder folder(context_);
   bool changed = true;
 
   // TODO(aminim): we could have a more efficient traversal by guiding the
@@ -955,9 +1031,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
     region->walk([&](Operation* op) {
       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
        changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
-        // TODO(jpienaar): Debug why we can't just return here. We end up with
-        // additional constant due to the propagation of constant into attached
-        // function if we return already.
+        return;
       }
 
       if (op->getDialect() != tf_dialect_) {
@@ -965,8 +1039,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
         return;
       }
 
-      // Before attempting inference, just try to fold the operation.
-      if (succeeded(folder.tryToFold(op))) return;
+      // Before attempting inference, just try to compute the folded
+      // value/shape.
+      if (succeeded(TryToFold(op))) return;
 
       // Best-effort shape inference in attached functions. Do not return
       // failure even if it doesn't get to fixed point.
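
Illustrative sketch (not part of the patch): the TryToFold path above computes an op's folded value and records it in a side table (results_, keyed by ValuePort) instead of materializing new constant ops in the IR. The self-contained toy program below demonstrates that general technique under simplified assumptions; Node, FoldCache, and try_fold are hypothetical names for illustration only and are not the MLIR/TensorFlow APIs used in this change.

#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Toy graph node: either a constant or a binary "add" over two inputs.
// (Hypothetical stand-in for an op in an IR; not an MLIR class.)
struct Node {
  std::string kind;             // "const" or "add"
  int64_t value = 0;            // payload for "const"
  std::vector<Node*> operands;  // inputs for "add"
};

// Side table of known constant results, keyed by node. This mirrors the
// patch's idea of recording computed attributes per ValuePort instead of
// rewriting the graph with new constant nodes.
using FoldCache = std::map<const Node*, int64_t>;

// Tries to compute the folded value of `node` from already-known operand
// values. On success the result is only recorded in `cache`; the node and
// its users are never mutated.
std::optional<int64_t> try_fold(const Node& node, FoldCache& cache) {
  if (auto it = cache.find(&node); it != cache.end()) return it->second;

  std::optional<int64_t> result;
  if (node.kind == "const") {
    result = node.value;
  } else if (node.kind == "add" && node.operands.size() == 2) {
    auto lhs = try_fold(*node.operands[0], cache);
    auto rhs = try_fold(*node.operands[1], cache);
    if (lhs && rhs) result = *lhs + *rhs;  // fold only if all inputs are known
  }

  if (result) cache[&node] = *result;  // record locally, no graph mutation
  return result;
}

int main() {
  Node a{"const", 2, {}};
  Node b{"const", 3, {}};
  Node sum{"add", 0, {&a, &b}};

  FoldCache cache;
  if (auto v = try_fold(sum, cache))
    std::cout << "folded value: " << *v << "\n";  // prints 5

  // `sum` is still an "add" node: the fold result lives only in `cache`,
  // which a later pass may or may not choose to materialize.
  std::cout << "node kind unchanged: " << sum.kind << "\n";
  return 0;
}

This mirrors the trade-off described in the commit message: a later canonicalization pass may still need to recompute or materialize these values, but shape inference itself no longer creates constant nodes while it runs.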