Internal change

PiperOrigin-RevId: 311558265
Change-Id: Ib91edbfdbd7d3442c72401a794283518393bc64d
A. Unique TensorFlower 2020-05-14 10:24:35 -07:00 committed by TensorFlower Gardener
parent 9dd3efb5aa
commit c3d351abd2
3 changed files with 49 additions and 137 deletions

@@ -3551,20 +3551,12 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
   if (!const_perm) return {};
   auto const_value = const_perm.value();
-  const auto elements = const_value.getValues<APInt>();
+  const auto &elements = const_value.getValues<APInt>();
   for (auto it : llvm::enumerate(elements)) {
     if (it.index() != it.value()) return {};
   }
-  // TODO(jpienaar): Remove when we handle this more generally.
-  if (op.getType() != op.x().getType()) {
-    // If the types don't match then only fold if all the operands are in the TF
-    // dialect.
-    for (auto user : op.getOperation()->getUsers())
-      if (user->getDialect() != op.getDialect()) return {};
-  }
   return op.x();
 }

@@ -3,8 +3,8 @@
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
   // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
   func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
-    // CHECK: %[[RESULT:.*]] = "tf.AddV2"
-    // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+    // CHECK-NOT: tf.Cast
+    // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
     // CHECK: return %[[RESULT]] : tensor<1xi32>
     %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
     %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
@@ -60,8 +60,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
   // CHECK-LABEL: func @simple_folding
   func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
-    // CHECK: %[[SHAPE:.*]] = "tf.Shape"
-    // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
+    // CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
+    // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
     // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
     // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
     %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
@@ -300,6 +300,13 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     return %0 : tensor<*xi32>
   }

+  // CHECK-LABEL: func @fold_cast
+  func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+    // CHECK-NOT: Cast
+    %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
+    return %0 : tensor<*xf32>
+  }
+
   // CHECK-LABEL: func @while_variant
   // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
   func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
@@ -355,6 +362,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
   // CHECK-LABEL: func @partitioned_call_func_const
   func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+    // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+    // CHECK: return %[[CONST]]
     return %arg0 : tensor<2xi32>
   }
@@ -401,18 +410,4 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
     return
   }
-
-  // CHECK-LABEL: const_fold
-  func @const_fold() -> () {
-    // CHECK: tf.Const
-    // CHECK-SAME: () -> tensor<4xi32>
-    %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
-    // CHECK: tf.Const
-    // CHECK-SAME: () -> tensor<4xi32>
-    %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
-    // CHECK: tf.Add
-    // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-    %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
-    return
-  }
 }

@@ -430,7 +430,6 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
 Attribute ComputeOutputComponent(const ValuePort& value_port,
                                  ValueQueryFn values) {
   LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
-  if (auto known = values(value_port)) return known;
   auto op = value_port.producer.dyn_cast<Operation*>();
   if (!op) return nullptr;
@@ -455,7 +454,6 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
     ValuePort op_port(op->getOperand(port[1]));
     return values(op_port);
   }
   return nullptr;
 }
@@ -477,11 +475,8 @@ class ShapeInference {
   }

   Attribute ComputeOutputComponent(const ValuePort& value_port) {
-    if (auto known_attr = results_[value_port]) return known_attr;
-    auto attr = ::mlir::TF::ComputeOutputComponent(
+    return ::mlir::TF::ComputeOutputComponent(
         value_port, [this](const ValuePort& port) { return results_[port]; });
-    RecordValue(value_port, attr);
-    return attr;
   }

   // Returns ShapeHandle if the op result could be computed as shape.
@@ -525,35 +520,19 @@ class ShapeInference {
   LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
                                                     int64_t max_iteration);

-  // Propagates any constant operand of call_op to the called function body's
-  // corresponding argument if the callee has only one use.
-  //
-  // TODO(b/154065712): Move this to a more general inter-procedural constant
-  // folding pass.
-  void PropagateConstantToCallee(CallOpInterface call_op,
-                                 SymbolRefAttr callee_sym, ModuleOp module);
-
-  // Propagates any constant return value of the callee function to the call
-  // op's corresponding result.
-  void PropagateConstantFromCallee(CallOpInterface call_op,
-                                   SymbolRefAttr callee_sym, ModuleOp module);
-
-  // Tries to compute the result of folding the op. This doesn't actually
-  // perform constant folding, it is just computes the equivalent constants.
-  // Returns whether it was able to compute constant values.
-  LogicalResult TryToFold(Operation* op);
-
  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
   // e.g., first element of OpResult produded) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;

   int64_t graph_version_;
+  MLIRContext* context_;
   Dialect* tf_dialect_;
 };

 ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
     : graph_version_(graph_version) {
+  context_ = context;
   tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
 }
@@ -602,6 +581,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
     auto ret = ComputeOutputComponent(front);
     if (!ret) continue;
+    RecordValue(front, ret);
     LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));

     // If worklist is empty, then this is the root query op.
@@ -706,14 +686,10 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
     size_t index = it.index();

     // If the operand is constant, then convert it to Tensor.
-    ValuePort vp(operand);
-    Attribute attr = ComputeOutputComponent(vp);
-    if (!attr && matchPattern(operand, m_Constant(&attr)))
-      RecordValue(vp, attr);
-    if (attr) {
+    ElementsAttr attr;
+    if (matchPattern(operand, m_Constant(&attr))) {
       tensorflow::Tensor* input_tensor = &tensors[index];
-      auto status =
-          tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
+      auto status = tensorflow::ConvertToTensor(attr, input_tensor);
       if (status.ok()) {
         input_tensors[index] = input_tensor;
       } else {
@@ -889,9 +865,13 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
   return success(all_succeeded);
 }

-void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
-                                               SymbolRefAttr callee_sym,
-                                               ModuleOp module) {
+// If the callee has only one use, propagates any constant operand of call_op to
+// the called function body's corresponding argument.
+//
+// TODO(b/154065712): Move this to a more general inter-procedural constant
+// folding pass.
+void PropagateConstantToCallee(CallOpInterface call_op,
+                               SymbolRefAttr callee_sym, ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   int num_uses = std::distance(func_uses->begin(), func_uses->end());
@@ -899,29 +879,31 @@ void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
   Operation* op = call_op.getOperation();
   if (num_uses == 1) {
     // If this is the only caller, and an operand is a constant, propagate
-    // the constant value inside the function.
+    // the constant inside the function.
     for (auto arg : func.getArguments()) {
-      auto operand = op->getOperand(arg.getArgNumber());
-      if (auto known_constant = ComputeOutputComponent(ValuePort(operand)))
-        RecordValue(ValuePort(arg), known_constant);
+      auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
+      if (isa_and_nonnull<TF::ConstOp>(operand)) {
+        arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
+      }
     }
   }
 }

-void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
-                                                 SymbolRefAttr callee_sym,
-                                                 ModuleOp module) {
+// Propagates any constant return value of the callee function to the call op's
+// corresponding result.
+void PropagateConstantFromCallee(CallOpInterface call_op,
+                                 SymbolRefAttr callee_sym, ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
-  // If the return value is a constant, use the constant as the value of
-  // the call return.
+  // If the return value is a constant, replace the call result with a constant.
   Operation* op = call_op.getOperation();
   OpBuilder builder(op);
   builder.setInsertionPointAfter(op);
   for (auto retval :
        llvm::enumerate(func.front().getTerminator()->getOperands())) {
-    ValuePort vp(retval.value());
-    if (auto known_constant = ComputeOutputComponent(vp)) {
-      RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
+    auto retval_op = retval.value().getDefiningOp();
+    if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
+      op->getResult(retval.index())
+          .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
     }
   }
 }
@@ -956,68 +938,10 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
   return success();
 }

-LogicalResult ShapeInference::TryToFold(Operation* op) {
-  // If any output result is known, then the op probably has been computed
-  // before.
-  if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
-    return success();
-
-  SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
-  SmallVector<OpFoldResult, 8> fold_results;
-
-  // Check to see if any operands to the operation is constant and whether
-  // the operation knows how to constant fold itself.
-  bool some_unknown = false;
-  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
-    if (!(constant_operands[i] =
-              ComputeOutputComponent(ValuePort(op->getOperand(i)))))
-      some_unknown = true;
-  }
-
-  // Attempt to constant fold the operation.
-  auto* abstract_op = op->getAbstractOperation();
-  if (abstract_op) {
-    if (failed(abstract_op->foldHook(op, constant_operands, fold_results)))
-      return failure();
-  } else {
-    Dialect* dialect = op->getDialect();
-    if (!dialect) return failure();
-    // Only attempt TF dialect fallback if there are no unknown operands.
-    if (some_unknown && dialect == tf_dialect_) return failure();
-    SmallVector<Attribute, 8> constants;
-    if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
-      return failure();
-    fold_results.assign(constants.begin(), constants.end());
-  }
-
-  for (auto result : zip(op->getResults(), fold_results)) {
-    auto fold_result = std::get<1>(result);
-    Attribute attr = nullptr;
-    if ((attr = fold_result.dyn_cast<Attribute>())) {
-      RecordValue(ValuePort(std::get<0>(result)), attr);
-    } else {
-      auto value = fold_result.get<Value>();
-      if ((attr = ComputeOutputComponent(ValuePort(value))))
-        RecordValue(ValuePort(std::get<0>(result)), attr);
-    }
-
-    if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
-      if (std::get<0>(result).getType() == eattr.getType()) continue;
-
-      // Inserts a cast back to the original type if any user is not in the
-      // TF dialect.
-      Type old_type = std::get<0>(result).getType();
-      std::get<0>(result).setType(eattr.getType());
-      AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
-                                         old_type);
-    }
-  }
-
-  return success();
-}
-
 LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
                                                       int64_t max_iteration) {
+  // An operation folder that is used to attempt folding before inference.
+  OperationFolder folder(context_);
   bool changed = true;

   // TODO(aminim): we could have a more efficient traversal by guiding the
@@ -1031,7 +955,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
   region->walk([&](Operation* op) {
     if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
       changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
-      return;
+      // TODO(jpienaar): Debug why we can't just return here. We end up with
+      // additional constant due to the propagation of constant into attached
+      // function if we return already.
     }

     if (op->getDialect() != tf_dialect_) {
@@ -1039,9 +965,8 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
       return;
     }

-    // Before attempting inference, just try to compute the folded
-    // value/shape.
-    if (succeeded(TryToFold(op))) return;
+    // Before attempting inference, just try to fold the operation.
+    if (succeeded(folder.tryToFold(op))) return;

     // Best-effort shape inference in attached functions. Do not return
     // failure even if it doesn't get to fixed point.