From c3d351abd20a814e7a8eae4e3d951b18667cbac8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 14 May 2020 10:24:35 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 311558265 Change-Id: Ib91edbfdbd7d3442c72401a794283518393bc64d --- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 10 +- .../tensorflow/tests/shape_inference.mlir | 31 ++-- .../tensorflow/transforms/shape_inference.cc | 145 +++++------------- 3 files changed, 49 insertions(+), 137 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index b21fef32cca..2007824369c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -3551,20 +3551,12 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) { if (!const_perm) return {}; auto const_value = const_perm.value(); - const auto elements = const_value.getValues(); + const auto &elements = const_value.getValues(); for (auto it : llvm::enumerate(elements)) { if (it.index() != it.value()) return {}; } - // TODO(jpienaar): Remove when we handle this more generally. - if (op.getType() != op.x().getType()) { - // If the types don't match then only fold if all the operands are in the TF - // dialect. - for (auto user : op.getOperation()->getUsers()) - if (user->getDialect() != op.getDialect()) return {}; - } - return op.x(); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index cfe8db9025e..160bba94cfc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -3,8 +3,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { - // CHECK: %[[RESULT:.*]] = "tf.AddV2" - // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NOT: tf.Cast + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %[[RESULT]] : tensor<1xi32> %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> @@ -60,8 +60,8 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @simple_folding func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor { -// CHECK: %[[SHAPE:.*]] = "tf.Shape" -// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] +// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]] // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> @@ -300,6 +300,13 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor<*xi32> } + // CHECK-LABEL: func @fold_cast + func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-NOT: Cast + %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) + return %0 : tensor<*xf32> + } + // CHECK-LABEL: func @while_variant // CHECK-SAME: -> tensor>> func @while_variant(%arg0: tensor>>) -> tensor { @@ -355,6 +362,8 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @partitioned_call_func_const func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CONST]] return %arg0 : tensor<2xi32> } @@ -401,18 +410,4 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor return } - - // CHECK-LABEL: const_fold - func @const_fold() -> () { - // CHECK: tf.Const - // CHECK-SAME: () -> tensor<4xi32> - %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> - // CHECK: tf.Const - // CHECK-SAME: () -> tensor<4xi32> - %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> - // CHECK: tf.Add - // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - return - } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 6a63e83be0f..5a2cae38062 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -430,7 +430,6 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, Attribute ComputeOutputComponent(const ValuePort& value_port, ValueQueryFn values) { LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for ")); - if (auto known = values(value_port)) return known; auto op = value_port.producer.dyn_cast(); if (!op) return nullptr; @@ -455,7 +454,6 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, ValuePort op_port(op->getOperand(port[1])); return values(op_port); } - return nullptr; } @@ -477,11 +475,8 @@ class ShapeInference { } Attribute ComputeOutputComponent(const ValuePort& value_port) { - if (auto known_attr = results_[value_port]) return known_attr; - auto attr = ::mlir::TF::ComputeOutputComponent( + return ::mlir::TF::ComputeOutputComponent( value_port, [this](const ValuePort& port) { return results_[port]; }); - RecordValue(value_port, attr); - return attr; } // Returns ShapeHandle if the op result could be computed as shape. @@ -525,35 +520,19 @@ class ShapeInference { LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, int64_t max_iteration); - // Propagates any constant operand of call_op to the called function body's - // corresponding argument if the callee has only one use. - // - // TODO(b/154065712): Move this to a more general inter-procedural constant - // folding pass. - void PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module); - - // Propagates any constant return value of the callee function to the call - // op's corresponding result. - void PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module); - - // Tries to compute the result of folding the op. This doesn't actually - // perform constant folding, it is just computes the equivalent constants. - // Returns whether it was able to compute constant values. - LogicalResult TryToFold(Operation* op); - private: // Mapping between ValuePort (which corresponds to an OpResult or smaller, // e.g., first element of OpResult produded) to an Attribute if the ValuePort // corresponds to a constant value. ValuePortResultMap results_; int64_t graph_version_; + MLIRContext* context_; Dialect* tf_dialect_; }; ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context) : graph_version_(graph_version) { + context_ = context; tf_dialect_ = context->getRegisteredDialect(); } @@ -602,6 +581,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, auto ret = ComputeOutputComponent(front); if (!ret) continue; + RecordValue(front, ret); LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); // If worklist is empty, then this is the root query op. @@ -706,14 +686,10 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { size_t index = it.index(); // If the operand is constant, then convert it to Tensor. - ValuePort vp(operand); - Attribute attr = ComputeOutputComponent(vp); - if (!attr && matchPattern(operand, m_Constant(&attr))) - RecordValue(vp, attr); - if (attr) { + ElementsAttr attr; + if (matchPattern(operand, m_Constant(&attr))) { tensorflow::Tensor* input_tensor = &tensors[index]; - auto status = - tensorflow::ConvertToTensor(attr.cast(), input_tensor); + auto status = tensorflow::ConvertToTensor(attr, input_tensor); if (status.ok()) { input_tensors[index] = input_tensor; } else { @@ -889,9 +865,13 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( return success(all_succeeded); } -void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, - ModuleOp module) { +// If the callee has only one use, propagates any constant operand of call_op to +// the called function body's corresponding argument. +// +// TODO(b/154065712): Move this to a more general inter-procedural constant +// folding pass. +void PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module) { auto func = module.lookupSymbol(callee_sym.getRootReference()); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); @@ -899,29 +879,31 @@ void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, Operation* op = call_op.getOperation(); if (num_uses == 1) { // If this is the only caller, and an operand is a constant, propagate - // the constant value inside the function. + // the constant inside the function. for (auto arg : func.getArguments()) { - auto operand = op->getOperand(arg.getArgNumber()); - if (auto known_constant = ComputeOutputComponent(ValuePort(operand))) - RecordValue(ValuePort(arg), known_constant); + auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); + if (isa_and_nonnull(operand)) { + arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); + } } } } -void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, - ModuleOp module) { +// Propagates any constant return value of the callee function to the call op's +// corresponding result. +void PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module) { auto func = module.lookupSymbol(callee_sym.getRootReference()); - // If the return value is a constant, use the constant as the value of - // the call return. + // If the return value is a constant, replace the call result with a constant. Operation* op = call_op.getOperation(); OpBuilder builder(op); builder.setInsertionPointAfter(op); for (auto retval : llvm::enumerate(func.front().getTerminator()->getOperands())) { - ValuePort vp(retval.value()); - if (auto known_constant = ComputeOutputComponent(vp)) { - RecordValue(ValuePort(op->getResult(retval.index())), known_constant); + auto retval_op = retval.value().getDefiningOp(); + if (isa_and_nonnull(retval_op)) { + op->getResult(retval.index()) + .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); } } } @@ -956,68 +938,10 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( return success(); } -LogicalResult ShapeInference::TryToFold(Operation* op) { - // If any output result is known, then the op probably has been computed - // before. - if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))]) - return success(); - - SmallVector constant_operands(op->getNumOperands()); - SmallVector fold_results; - - // Check to see if any operands to the operation is constant and whether - // the operation knows how to constant fold itself. - bool some_unknown = false; - for (int i = 0, e = op->getNumOperands(); i != e; ++i) { - if (!(constant_operands[i] = - ComputeOutputComponent(ValuePort(op->getOperand(i))))) - some_unknown = true; - } - - // Attempt to constant fold the operation. - auto* abstract_op = op->getAbstractOperation(); - if (abstract_op) { - if (failed(abstract_op->foldHook(op, constant_operands, fold_results))) - return failure(); - } else { - Dialect* dialect = op->getDialect(); - if (!dialect) return failure(); - // Only attempt TF dialect fallback if there are no unknown operands. - if (some_unknown && dialect == tf_dialect_) return failure(); - SmallVector constants; - if (failed(dialect->constantFoldHook(op, constant_operands, constants))) - return failure(); - fold_results.assign(constants.begin(), constants.end()); - } - - for (auto result : zip(op->getResults(), fold_results)) { - auto fold_result = std::get<1>(result); - Attribute attr = nullptr; - if ((attr = fold_result.dyn_cast())) { - RecordValue(ValuePort(std::get<0>(result)), attr); - } else { - auto value = fold_result.get(); - if ((attr = ComputeOutputComponent(ValuePort(value)))) - RecordValue(ValuePort(std::get<0>(result)), attr); - } - - if (ElementsAttr eattr = attr.dyn_cast_or_null()) { - if (std::get<0>(result).getType() == eattr.getType()) continue; - - // Inserts a cast back to the original type if any user is not in the - // TF dialect. - Type old_type = std::get<0>(result).getType(); - std::get<0>(result).setType(eattr.getType()); - AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_, - old_type); - } - } - - return success(); -} - LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, int64_t max_iteration) { + // An operation folder that is used to attempt folding before inference._ + OperationFolder folder(context_); bool changed = true; // TODO(aminim): we could have a more efficient traversal by guiding the @@ -1031,7 +955,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, region->walk([&](Operation* op) { if (auto infer_ti = dyn_cast(op)) { changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); - return; + // TODO(jpienaar): Debug why we can't just return here. We end up with + // additional constant due to the propagation of constant into attached + // function if we return already. } if (op->getDialect() != tf_dialect_) { @@ -1039,9 +965,8 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, return; } - // Before attempting inference, just try to compute the folded - // value/shape. - if (succeeded(TryToFold(op))) return; + // Before attempting inference, just try to fold the operation. + if (succeeded(folder.tryToFold(op))) return; // Best-effort shape inference in attached functions. Do not return // failure even if it doesn't get to fixed point.