Internal change
PiperOrigin-RevId: 311558265
Change-Id: Ib91edbfdbd7d3442c72401a794283518393bc64d

parent 9dd3efb5aa
commit c3d351abd2
@@ -3551,20 +3551,12 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
   if (!const_perm) return {};

   auto const_value = const_perm.value();
-  const auto elements = const_value.getValues<APInt>();
+  const auto &elements = const_value.getValues<APInt>();

   for (auto it : llvm::enumerate(elements)) {
     if (it.index() != it.value()) return {};
   }

-  // TODO(jpienaar): Remove when we handle this more generally.
-  if (op.getType() != op.x().getType()) {
-    // If the types don't match then only fold if all the operands are in the TF
-    // dialect.
-    for (auto user : op.getOperation()->getUsers())
-      if (user->getDialect() != op.getDialect()) return {};
-  }
-
   return op.x();
 }

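For context, FoldIdentityTranspose only fires when the constant permutation is the identity (every element equals its index), in which case the transpose is replaced by its input. A minimal sketch of the kind of IR it targets (hypothetical example, not part of this change):

// Hypothetical input: the permutation [0, 1] is the identity, so the
// canonicalizer can fold the transpose away and use %arg0 directly.
func @identity_transpose(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %perm = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %0 = "tf.Transpose"(%arg0, %perm) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}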
@@ -3,8 +3,8 @@
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
 // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
 func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"
-  // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  // CHECK-NOT: tf.Cast
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
   // CHECK: return %[[RESULT]] : tensor<1xi32>
   %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
   %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>

@@ -60,8 +60,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {

 // CHECK-LABEL: func @simple_folding
 func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
-  // CHECK: %[[SHAPE:.*]] = "tf.Shape"
-  // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
+  // CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
+  // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
   // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
   // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
   %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>

@@ -300,6 +300,13 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
   return %0 : tensor<*xi32>
 }

+// CHECK-LABEL: func @fold_cast
+func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK-NOT: Cast
+  %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @while_variant
 // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
 func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {

@@ -355,6 +362,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {

 // CHECK-LABEL: func @partitioned_call_func_const
 func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CONST]]
   return %arg0 : tensor<2xi32>
 }

@@ -401,18 +410,4 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
   %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
   return
 }
-
-// CHECK-LABEL: const_fold
-func @const_fold() -> () {
-  // CHECK: tf.Const
-  // CHECK-SAME: () -> tensor<4xi32>
-  %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
-  // CHECK: tf.Const
-  // CHECK-SAME: () -> tensor<4xi32>
-  %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
-  // CHECK: tf.Add
-  // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
-  return
-}
 }

@@ -430,7 +430,6 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
 Attribute ComputeOutputComponent(const ValuePort& value_port,
                                  ValueQueryFn values) {
   LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
-  if (auto known = values(value_port)) return known;

   auto op = value_port.producer.dyn_cast<Operation*>();
   if (!op) return nullptr;

@@ -455,7 +454,6 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
     ValuePort op_port(op->getOperand(port[1]));
     return values(op_port);
   }
-
   return nullptr;
 }

@@ -477,11 +475,8 @@ class ShapeInference {
   }

   Attribute ComputeOutputComponent(const ValuePort& value_port) {
-    if (auto known_attr = results_[value_port]) return known_attr;
-    auto attr = ::mlir::TF::ComputeOutputComponent(
+    return ::mlir::TF::ComputeOutputComponent(
         value_port, [this](const ValuePort& port) { return results_[port]; });
-    RecordValue(value_port, attr);
-    return attr;
   }

   // Returns ShapeHandle if the op result could be computed as shape.

@@ -525,35 +520,19 @@ class ShapeInference {
   LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
                                                     int64_t max_iteration);

-  // Propagates any constant operand of call_op to the called function body's
-  // corresponding argument if the callee has only one use.
-  //
-  // TODO(b/154065712): Move this to a more general inter-procedural constant
-  // folding pass.
-  void PropagateConstantToCallee(CallOpInterface call_op,
-                                 SymbolRefAttr callee_sym, ModuleOp module);
-
-  // Propagates any constant return value of the callee function to the call
-  // op's corresponding result.
-  void PropagateConstantFromCallee(CallOpInterface call_op,
-                                   SymbolRefAttr callee_sym, ModuleOp module);
-
-  // Tries to compute the result of folding the op. This doesn't actually
-  // perform constant folding, it is just computes the equivalent constants.
-  // Returns whether it was able to compute constant values.
-  LogicalResult TryToFold(Operation* op);
-
  private:
   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
   // e.g., first element of OpResult produded) to an Attribute if the ValuePort
   // corresponds to a constant value.
   ValuePortResultMap results_;
   int64_t graph_version_;
+  MLIRContext* context_;
   Dialect* tf_dialect_;
 };

 ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
     : graph_version_(graph_version) {
+  context_ = context;
   tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
 }

@@ -602,6 +581,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
     auto ret = ComputeOutputComponent(front);
     if (!ret) continue;

+    RecordValue(front, ret);
     LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));

     // If worklist is empty, then this is the root query op.

@@ -706,14 +686,10 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
     size_t index = it.index();

     // If the operand is constant, then convert it to Tensor.
-    ValuePort vp(operand);
-    Attribute attr = ComputeOutputComponent(vp);
-    if (!attr && matchPattern(operand, m_Constant(&attr)))
-      RecordValue(vp, attr);
-    if (attr) {
+    ElementsAttr attr;
+    if (matchPattern(operand, m_Constant(&attr))) {
       tensorflow::Tensor* input_tensor = &tensors[index];
-      auto status =
-          tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
+      auto status = tensorflow::ConvertToTensor(attr, input_tensor);
       if (status.ok()) {
         input_tensors[index] = input_tensor;
       } else {

@@ -889,9 +865,13 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
   return success(all_succeeded);
 }

-void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
-                                               SymbolRefAttr callee_sym,
-                                               ModuleOp module) {
+// If the callee has only one use, propagates any constant operand of call_op to
+// the called function body's corresponding argument.
+//
+// TODO(b/154065712): Move this to a more general inter-procedural constant
+// folding pass.
+void PropagateConstantToCallee(CallOpInterface call_op,
+                               SymbolRefAttr callee_sym, ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   int num_uses = std::distance(func_uses->begin(), func_uses->end());

@@ -899,29 +879,31 @@ void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
   Operation* op = call_op.getOperation();
   if (num_uses == 1) {
     // If this is the only caller, and an operand is a constant, propagate
-    // the constant value inside the function.
+    // the constant inside the function.
     for (auto arg : func.getArguments()) {
-      auto operand = op->getOperand(arg.getArgNumber());
-      if (auto known_constant = ComputeOutputComponent(ValuePort(operand)))
-        RecordValue(ValuePort(arg), known_constant);
+      auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
+      if (isa_and_nonnull<TF::ConstOp>(operand)) {
+        arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
+      }
     }
   }
 }

-void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
-                                                 SymbolRefAttr callee_sym,
-                                                 ModuleOp module) {
+// Propagates any constant return value of the callee function to the call op's
+// corresponding result.
+void PropagateConstantFromCallee(CallOpInterface call_op,
+                                 SymbolRefAttr callee_sym, ModuleOp module) {
   auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
-  // If the return value is a constant, use the constant as the value of
-  // the call return.
+  // If the return value is a constant, replace the call result with a constant.
   Operation* op = call_op.getOperation();
   OpBuilder builder(op);
   builder.setInsertionPointAfter(op);
   for (auto retval :
        llvm::enumerate(func.front().getTerminator()->getOperands())) {
-    ValuePort vp(retval.value());
-    if (auto known_constant = ComputeOutputComponent(vp)) {
-      RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
+    auto retval_op = retval.value().getDefiningOp();
+    if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
+      op->getResult(retval.index())
+          .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
     }
   }
 }

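As a rough illustration of what the PropagateConstantToCallee / PropagateConstantFromCallee helpers do when a callee has a single call site, consider this hypothetical IR (the function names and the tf.PartitionedCall caller are made up for the sketch, not taken from this change):

// Hypothetical single-use callee: the constant operand %cst can be cloned
// into @callee to replace uses of %arg0, and a constant returned by @callee
// could likewise be cloned next to the call to replace the call result.
func @caller() -> tensor<i32> {
  %cst = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  %0 = "tf.PartitionedCall"(%cst) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<i32>
}
func @callee(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}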
@@ -956,68 +938,10 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
   return success();
 }

-LogicalResult ShapeInference::TryToFold(Operation* op) {
-  // If any output result is known, then the op probably has been computed
-  // before.
-  if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
-    return success();
-
-  SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
-  SmallVector<OpFoldResult, 8> fold_results;
-
-  // Check to see if any operands to the operation is constant and whether
-  // the operation knows how to constant fold itself.
-  bool some_unknown = false;
-  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
-    if (!(constant_operands[i] =
-              ComputeOutputComponent(ValuePort(op->getOperand(i)))))
-      some_unknown = true;
-  }
-
-  // Attempt to constant fold the operation.
-  auto* abstract_op = op->getAbstractOperation();
-  if (abstract_op) {
-    if (failed(abstract_op->foldHook(op, constant_operands, fold_results)))
-      return failure();
-  } else {
-    Dialect* dialect = op->getDialect();
-    if (!dialect) return failure();
-    // Only attempt TF dialect fallback if there are no unknown operands.
-    if (some_unknown && dialect == tf_dialect_) return failure();
-    SmallVector<Attribute, 8> constants;
-    if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
-      return failure();
-    fold_results.assign(constants.begin(), constants.end());
-  }
-
-  for (auto result : zip(op->getResults(), fold_results)) {
-    auto fold_result = std::get<1>(result);
-    Attribute attr = nullptr;
-    if ((attr = fold_result.dyn_cast<Attribute>())) {
-      RecordValue(ValuePort(std::get<0>(result)), attr);
-    } else {
-      auto value = fold_result.get<Value>();
-      if ((attr = ComputeOutputComponent(ValuePort(value))))
-        RecordValue(ValuePort(std::get<0>(result)), attr);
-    }
-
-    if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
-      if (std::get<0>(result).getType() == eattr.getType()) continue;
-
-      // Inserts a cast back to the original type if any user is not in the
-      // TF dialect.
-      Type old_type = std::get<0>(result).getType();
-      std::get<0>(result).setType(eattr.getType());
-      AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
-                                         old_type);
-    }
-  }
-
-  return success();
-}
-
 LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
                                                       int64_t max_iteration) {
+  // An operation folder that is used to attempt folding before inference.
+  OperationFolder folder(context_);
   bool changed = true;

   // TODO(aminim): we could have a more efficient traversal by guiding the

@@ -1031,7 +955,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
   region->walk([&](Operation* op) {
     if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
       changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
-      return;
+      // TODO(jpienaar): Debug why we can't just return here. We end up with
+      // additional constant due to the propagation of constant into attached
+      // function if we return already.
     }

     if (op->getDialect() != tf_dialect_) {

@@ -1039,9 +965,8 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
       return;
     }

-    // Before attempting inference, just try to compute the folded
-    // value/shape.
-    if (succeeded(TryToFold(op))) return;
+    // Before attempting inference, just try to fold the operation.
+    if (succeeded(folder.tryToFold(op))) return;

     // Best-effort shape inference in attached functions. Do not return
     // failure even if it doesn't get to fixed point.