Internal change
PiperOrigin-RevId: 311558265 Change-Id: Ib91edbfdbd7d3442c72401a794283518393bc64d
This commit is contained in:
parent
9dd3efb5aa
commit
c3d351abd2
|
@ -3551,20 +3551,12 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
|
||||||
if (!const_perm) return {};
|
if (!const_perm) return {};
|
||||||
|
|
||||||
auto const_value = const_perm.value();
|
auto const_value = const_perm.value();
|
||||||
const auto elements = const_value.getValues<APInt>();
|
const auto &elements = const_value.getValues<APInt>();
|
||||||
|
|
||||||
for (auto it : llvm::enumerate(elements)) {
|
for (auto it : llvm::enumerate(elements)) {
|
||||||
if (it.index() != it.value()) return {};
|
if (it.index() != it.value()) return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jpienaar): Remove when we handle this more generally.
|
|
||||||
if (op.getType() != op.x().getType()) {
|
|
||||||
// If the types don't match then only fold if all the operands are in the TF
|
|
||||||
// dialect.
|
|
||||||
for (auto user : op.getOperation()->getUsers())
|
|
||||||
if (user->getDialect() != op.getDialect()) return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
return op.x();
|
return op.x();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
|
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
|
||||||
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
|
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
|
||||||
func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
|
func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
|
||||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"
|
// CHECK-NOT: tf.Cast
|
||||||
// CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||||
// CHECK: return %[[RESULT]] : tensor<1xi32>
|
// CHECK: return %[[RESULT]] : tensor<1xi32>
|
||||||
%0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
|
%0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
|
||||||
%1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
|
%1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
|
||||||
|
@ -60,8 +60,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// CHECK-LABEL: func @simple_folding
|
// CHECK-LABEL: func @simple_folding
|
||||||
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
|
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
|
||||||
// CHECK: %[[SHAPE:.*]] = "tf.Shape"
|
// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
|
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
|
||||||
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
||||||
// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
|
// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
|
||||||
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
|
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
|
||||||
|
@ -300,6 +300,13 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||||
return %0 : tensor<*xi32>
|
return %0 : tensor<*xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_cast
|
||||||
|
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK-NOT: Cast
|
||||||
|
%0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @while_variant
|
// CHECK-LABEL: func @while_variant
|
||||||
// CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
|
// CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
|
||||||
func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
|
func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
|
||||||
|
@ -355,6 +362,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// CHECK-LABEL: func @partitioned_call_func_const
|
// CHECK-LABEL: func @partitioned_call_func_const
|
||||||
func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||||
|
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
// CHECK: return %[[CONST]]
|
||||||
return %arg0 : tensor<2xi32>
|
return %arg0 : tensor<2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,18 +410,4 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||||
%40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
%40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: const_fold
|
|
||||||
func @const_fold() -> () {
|
|
||||||
// CHECK: tf.Const
|
|
||||||
// CHECK-SAME: () -> tensor<4xi32>
|
|
||||||
%0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
|
|
||||||
// CHECK: tf.Const
|
|
||||||
// CHECK-SAME: () -> tensor<4xi32>
|
|
||||||
%1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
|
|
||||||
// CHECK: tf.Add
|
|
||||||
// CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
|
||||||
%2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -430,7 +430,6 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
|
||||||
Attribute ComputeOutputComponent(const ValuePort& value_port,
|
Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||||
ValueQueryFn values) {
|
ValueQueryFn values) {
|
||||||
LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
|
LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
|
||||||
if (auto known = values(value_port)) return known;
|
|
||||||
|
|
||||||
auto op = value_port.producer.dyn_cast<Operation*>();
|
auto op = value_port.producer.dyn_cast<Operation*>();
|
||||||
if (!op) return nullptr;
|
if (!op) return nullptr;
|
||||||
|
@ -455,7 +454,6 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||||
ValuePort op_port(op->getOperand(port[1]));
|
ValuePort op_port(op->getOperand(port[1]));
|
||||||
return values(op_port);
|
return values(op_port);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -477,11 +475,8 @@ class ShapeInference {
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute ComputeOutputComponent(const ValuePort& value_port) {
|
Attribute ComputeOutputComponent(const ValuePort& value_port) {
|
||||||
if (auto known_attr = results_[value_port]) return known_attr;
|
return ::mlir::TF::ComputeOutputComponent(
|
||||||
auto attr = ::mlir::TF::ComputeOutputComponent(
|
|
||||||
value_port, [this](const ValuePort& port) { return results_[port]; });
|
value_port, [this](const ValuePort& port) { return results_[port]; });
|
||||||
RecordValue(value_port, attr);
|
|
||||||
return attr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns ShapeHandle if the op result could be computed as shape.
|
// Returns ShapeHandle if the op result could be computed as shape.
|
||||||
|
@ -525,35 +520,19 @@ class ShapeInference {
|
||||||
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||||
int64_t max_iteration);
|
int64_t max_iteration);
|
||||||
|
|
||||||
// Propagates any constant operand of call_op to the called function body's
|
|
||||||
// corresponding argument if the callee has only one use.
|
|
||||||
//
|
|
||||||
// TODO(b/154065712): Move this to a more general inter-procedural constant
|
|
||||||
// folding pass.
|
|
||||||
void PropagateConstantToCallee(CallOpInterface call_op,
|
|
||||||
SymbolRefAttr callee_sym, ModuleOp module);
|
|
||||||
|
|
||||||
// Propagates any constant return value of the callee function to the call
|
|
||||||
// op's corresponding result.
|
|
||||||
void PropagateConstantFromCallee(CallOpInterface call_op,
|
|
||||||
SymbolRefAttr callee_sym, ModuleOp module);
|
|
||||||
|
|
||||||
// Tries to compute the result of folding the op. This doesn't actually
|
|
||||||
// perform constant folding, it is just computes the equivalent constants.
|
|
||||||
// Returns whether it was able to compute constant values.
|
|
||||||
LogicalResult TryToFold(Operation* op);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
||||||
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
|
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
|
||||||
// corresponds to a constant value.
|
// corresponds to a constant value.
|
||||||
ValuePortResultMap results_;
|
ValuePortResultMap results_;
|
||||||
int64_t graph_version_;
|
int64_t graph_version_;
|
||||||
|
MLIRContext* context_;
|
||||||
Dialect* tf_dialect_;
|
Dialect* tf_dialect_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
|
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
|
||||||
: graph_version_(graph_version) {
|
: graph_version_(graph_version) {
|
||||||
|
context_ = context;
|
||||||
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
|
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -602,6 +581,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||||
auto ret = ComputeOutputComponent(front);
|
auto ret = ComputeOutputComponent(front);
|
||||||
if (!ret) continue;
|
if (!ret) continue;
|
||||||
|
|
||||||
|
RecordValue(front, ret);
|
||||||
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
|
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
|
||||||
|
|
||||||
// If worklist is empty, then this is the root query op.
|
// If worklist is empty, then this is the root query op.
|
||||||
|
@ -706,14 +686,10 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||||
size_t index = it.index();
|
size_t index = it.index();
|
||||||
|
|
||||||
// If the operand is constant, then convert it to Tensor.
|
// If the operand is constant, then convert it to Tensor.
|
||||||
ValuePort vp(operand);
|
ElementsAttr attr;
|
||||||
Attribute attr = ComputeOutputComponent(vp);
|
if (matchPattern(operand, m_Constant(&attr))) {
|
||||||
if (!attr && matchPattern(operand, m_Constant(&attr)))
|
|
||||||
RecordValue(vp, attr);
|
|
||||||
if (attr) {
|
|
||||||
tensorflow::Tensor* input_tensor = &tensors[index];
|
tensorflow::Tensor* input_tensor = &tensors[index];
|
||||||
auto status =
|
auto status = tensorflow::ConvertToTensor(attr, input_tensor);
|
||||||
tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
|
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
input_tensors[index] = input_tensor;
|
input_tensors[index] = input_tensor;
|
||||||
} else {
|
} else {
|
||||||
|
@ -889,9 +865,13 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||||
return success(all_succeeded);
|
return success(all_succeeded);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
|
// If the callee has only one use, propagates any constant operand of call_op to
|
||||||
SymbolRefAttr callee_sym,
|
// the called function body's corresponding argument.
|
||||||
ModuleOp module) {
|
//
|
||||||
|
// TODO(b/154065712): Move this to a more general inter-procedural constant
|
||||||
|
// folding pass.
|
||||||
|
void PropagateConstantToCallee(CallOpInterface call_op,
|
||||||
|
SymbolRefAttr callee_sym, ModuleOp module) {
|
||||||
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
||||||
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||||
int num_uses = std::distance(func_uses->begin(), func_uses->end());
|
int num_uses = std::distance(func_uses->begin(), func_uses->end());
|
||||||
|
@ -899,29 +879,31 @@ void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
|
||||||
Operation* op = call_op.getOperation();
|
Operation* op = call_op.getOperation();
|
||||||
if (num_uses == 1) {
|
if (num_uses == 1) {
|
||||||
// If this is the only caller, and an operand is a constant, propagate
|
// If this is the only caller, and an operand is a constant, propagate
|
||||||
// the constant value inside the function.
|
// the constant inside the function.
|
||||||
for (auto arg : func.getArguments()) {
|
for (auto arg : func.getArguments()) {
|
||||||
auto operand = op->getOperand(arg.getArgNumber());
|
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
|
||||||
if (auto known_constant = ComputeOutputComponent(ValuePort(operand)))
|
if (isa_and_nonnull<TF::ConstOp>(operand)) {
|
||||||
RecordValue(ValuePort(arg), known_constant);
|
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
|
// Propagates any constant return value of the callee function to the call op's
|
||||||
SymbolRefAttr callee_sym,
|
// corresponding result.
|
||||||
ModuleOp module) {
|
void PropagateConstantFromCallee(CallOpInterface call_op,
|
||||||
|
SymbolRefAttr callee_sym, ModuleOp module) {
|
||||||
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
||||||
// If the return value is a constant, use the constant as the value of
|
// If the return value is a constant, replace the call result with a constant.
|
||||||
// the call return.
|
|
||||||
Operation* op = call_op.getOperation();
|
Operation* op = call_op.getOperation();
|
||||||
OpBuilder builder(op);
|
OpBuilder builder(op);
|
||||||
builder.setInsertionPointAfter(op);
|
builder.setInsertionPointAfter(op);
|
||||||
for (auto retval :
|
for (auto retval :
|
||||||
llvm::enumerate(func.front().getTerminator()->getOperands())) {
|
llvm::enumerate(func.front().getTerminator()->getOperands())) {
|
||||||
ValuePort vp(retval.value());
|
auto retval_op = retval.value().getDefiningOp();
|
||||||
if (auto known_constant = ComputeOutputComponent(vp)) {
|
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
|
||||||
RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
|
op->getResult(retval.index())
|
||||||
|
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -956,68 +938,10 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ShapeInference::TryToFold(Operation* op) {
|
|
||||||
// If any output result is known, then the op probably has been computed
|
|
||||||
// before.
|
|
||||||
if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
|
|
||||||
return success();
|
|
||||||
|
|
||||||
SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
|
|
||||||
SmallVector<OpFoldResult, 8> fold_results;
|
|
||||||
|
|
||||||
// Check to see if any operands to the operation is constant and whether
|
|
||||||
// the operation knows how to constant fold itself.
|
|
||||||
bool some_unknown = false;
|
|
||||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
|
||||||
if (!(constant_operands[i] =
|
|
||||||
ComputeOutputComponent(ValuePort(op->getOperand(i)))))
|
|
||||||
some_unknown = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to constant fold the operation.
|
|
||||||
auto* abstract_op = op->getAbstractOperation();
|
|
||||||
if (abstract_op) {
|
|
||||||
if (failed(abstract_op->foldHook(op, constant_operands, fold_results)))
|
|
||||||
return failure();
|
|
||||||
} else {
|
|
||||||
Dialect* dialect = op->getDialect();
|
|
||||||
if (!dialect) return failure();
|
|
||||||
// Only attempt TF dialect fallback if there are no unknown operands.
|
|
||||||
if (some_unknown && dialect == tf_dialect_) return failure();
|
|
||||||
SmallVector<Attribute, 8> constants;
|
|
||||||
if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
|
|
||||||
return failure();
|
|
||||||
fold_results.assign(constants.begin(), constants.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto result : zip(op->getResults(), fold_results)) {
|
|
||||||
auto fold_result = std::get<1>(result);
|
|
||||||
Attribute attr = nullptr;
|
|
||||||
if ((attr = fold_result.dyn_cast<Attribute>())) {
|
|
||||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
|
||||||
} else {
|
|
||||||
auto value = fold_result.get<Value>();
|
|
||||||
if ((attr = ComputeOutputComponent(ValuePort(value))))
|
|
||||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
|
|
||||||
if (std::get<0>(result).getType() == eattr.getType()) continue;
|
|
||||||
|
|
||||||
// Inserts a cast back to the original type if any user is not in the
|
|
||||||
// TF dialect.
|
|
||||||
Type old_type = std::get<0>(result).getType();
|
|
||||||
std::get<0>(result).setType(eattr.getType());
|
|
||||||
AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
|
|
||||||
old_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||||
int64_t max_iteration) {
|
int64_t max_iteration) {
|
||||||
|
// An operation folder that is used to attempt folding before inference._
|
||||||
|
OperationFolder folder(context_);
|
||||||
bool changed = true;
|
bool changed = true;
|
||||||
|
|
||||||
// TODO(aminim): we could have a more efficient traversal by guiding the
|
// TODO(aminim): we could have a more efficient traversal by guiding the
|
||||||
|
@ -1031,7 +955,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||||
region->walk([&](Operation* op) {
|
region->walk([&](Operation* op) {
|
||||||
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
|
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
|
||||||
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
|
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
|
||||||
return;
|
// TODO(jpienaar): Debug why we can't just return here. We end up with
|
||||||
|
// additional constant due to the propagation of constant into attached
|
||||||
|
// function if we return already.
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op->getDialect() != tf_dialect_) {
|
if (op->getDialect() != tf_dialect_) {
|
||||||
|
@ -1039,9 +965,8 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before attempting inference, just try to compute the folded
|
// Before attempting inference, just try to fold the operation.
|
||||||
// value/shape.
|
if (succeeded(folder.tryToFold(op))) return;
|
||||||
if (succeeded(TryToFold(op))) return;
|
|
||||||
|
|
||||||
// Best-effort shape inference in attached functions. Do not return
|
// Best-effort shape inference in attached functions. Do not return
|
||||||
// failure even if it doesn't get to fixed point.
|
// failure even if it doesn't get to fixed point.
|
||||||
|
|
Loading…
Reference in New Issue