Refactor MLIR TF shape inference to have a context

This enables reusing partial results that have already been computed by caching query results (ValuePortResultMap). It also reduces the number of arguments being passed around (otherwise, in the follow-up, I'd need to pass a context everywhere). This should be an NFC change.

PiperOrigin-RevId: 311166241
Change-Id: Icb6ea66c6c16a06d4bc9077225f1d7a783548dca
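For illustration only (this sketch is not part of the commit and uses simplified stand-in types rather than the real MLIR/TensorFlow ones), the refactoring amounts to moving state such as the graph version and the per-query result cache out of function parameters and into a context object whose cache persists across queries:

#include <cstdint>
#include <unordered_map>

// Simplified stand-ins for the real MLIR types; illustration only.
struct ValuePort {
  int result_index = 0;
  bool operator==(const ValuePort& other) const {
    return result_index == other.result_index;
  }
};
struct Attribute {
  int64_t value = 0;
};
struct ValuePortHasher {
  std::size_t operator()(const ValuePort& p) const {
    return static_cast<std::size_t>(p.result_index);
  }
};
using ValuePortResultMap =
    std::unordered_map<ValuePort, Attribute, ValuePortHasher>;

// Before: helpers were free functions, and the caller threaded the graph
// version plus a short-lived result cache through every call.
Attribute ComputeOutputComponent(const ValuePort& port, int64_t graph_version,
                                 ValuePortResultMap& evaluated);

// After: a ShapeInference context owns the shared state, so helpers take
// fewer arguments and cached partial results survive across queries.
class ShapeInference {
 public:
  explicit ShapeInference(int64_t graph_version)
      : graph_version_(graph_version) {}

  Attribute ComputeOutputComponent(const ValuePort& port) {
    auto it = results_.find(port);
    if (it != results_.end()) return it->second;  // reuse a cached result
    Attribute computed;            // placeholder for the real computation
    computed.value = graph_version_;
    RecordValue(port, computed);
    return computed;
  }

  void RecordValue(const ValuePort& port, Attribute value) {
    results_[port] = value;
  }

 private:
  int64_t graph_version_;
  ValuePortResultMap results_;  // cached query results, reused across calls
};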
Author: Jacques Pienaar (2020-05-12 11:24:53 -07:00)
Committer: TensorFlower Gardener
Parent: 88acf9fcc5
Commit: a2afd0e358


@ -66,8 +66,7 @@ using tensorflow::shape_inference::ShapeHandle;
namespace mlir {
namespace TF {
namespace {
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
FuncOp func) {
Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
// Find any return ops.
SmallVector<ReturnOp, 4> return_ops;
for (Block& block : func) {
@ -137,9 +136,9 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result,
/*truncate=*/b.getBoolAttr(false));
}
return mlir::Value(cast_op);
return Value(cast_op);
};
for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
for (OpOperand& use : make_early_inc_range(result.getUses())) {
if (use.getOwner()->getDialect() != tf_dialect &&
!IsSupportedNonTFOp(use.getOwner()))
use.set(get_cast_op());
@ -162,7 +161,7 @@ Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
Operation* op, Dialect* tf_dialect) {
bool changed = false;
for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
for (auto entry : zip(pass_through_operands, op->getResults())) {
Type operand_type = std::get<0>(entry).getType();
Value result = std::get<1>(entry);
if (result.getType() == operand_type) continue;
@ -204,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) {
tf_dialect);
}
// TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) {
return InferShapeForPassThroughOps(
tensor_cast.getOperation()->getOperands(), op, tf_dialect);
}
@ -254,7 +253,7 @@ GetSubtypes(Type type) {
// match the i-th operand type). Returns true if anything is changed.
bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
bool changed = false;
for (auto entry : llvm::zip(operands, results)) {
for (auto entry : zip(operands, results)) {
Type operand_type = std::get<0>(entry).getType();
Type result_type = std::get<1>(entry).getType();
if (operand_type == result_type) continue;
@ -291,14 +290,13 @@ bool InferShapeForCall(Operation* op) {
CallInterfaceCallable callable = call_op.getCallableForCallee();
SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>();
if (!sym) return false;
FuncOp func =
dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
if (!func) return false;
bool changed = false;
// Map each of the results of the call to the returned type of the
// function.
for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) {
for (auto result : zip(op->getResults(), func.getType().getResults())) {
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
// Skip already statically shaped results.
if (!CanBeRefined(std::get<0>(result).getType())) continue;
@ -335,7 +333,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
// Map each of the results of the call to the returned type of the
// function.
bool changed = false;
for (auto result : llvm::zip(op->getResults(), inferred)) {
for (auto result : zip(op->getResults(), inferred)) {
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
// Inserts a cast back to the original type if any user is not in the
@ -356,7 +354,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
// so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
// scalar value).
struct ValuePort {
llvm::PointerUnion<Operation*, BlockArgument> producer;
PointerUnion<Operation*, BlockArgument> producer;
SmallVector<unsigned int, 2> port;
bool operator==(const ValuePort& other) const {
@ -374,39 +372,38 @@ struct ValuePort {
port = {0};
}
}
ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
ValuePort(PointerUnion<Operation*, BlockArgument> producer,
SmallVector<unsigned int, 2> port)
: producer(producer), port(port) {}
llvm::raw_ostream& print(llvm::raw_ostream& os) const {
raw_ostream& print(raw_ostream& os) const {
if (auto* op = producer.dyn_cast<Operation*>())
os << "op " << op->getName();
if (auto ba = producer.dyn_cast<BlockArgument>())
os << "block_arg " << ba.getArgNumber();
os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
return os;
}
};
struct ValuePortHasher {
std::size_t operator()(const ValuePort& other) const {
return llvm::hash_combine(
llvm::hash_value(other.producer.getOpaqueValue()),
llvm::hash_value(ArrayRef<unsigned int>(other.port)));
return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
hash_value(ArrayRef<unsigned int>(other.port)));
}
};
using ValuePortResultMap =
std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
using ComputedQueryFn = function_ref<bool(ValuePort)>;
using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
using ValuePortInputs = SmallVectorImpl<ValuePort>;
// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
// intended to be switched to op interfaces once more refined.
LogicalResult InputsRequiredForOutput(ValuePort value_port,
ComputedQueryFn has_been_computed,
ValuePortInputs* inputs) {
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
ComputedQueryFn has_been_computed,
ValuePortInputs* inputs) {
auto op = value_port.producer.dyn_cast<Operation*>();
auto& port = value_port.port;
if (!op) return failure();
@ -460,26 +457,94 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
return nullptr;
}
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
// Context used during ShapeInference. This class contains common information
// that is required by the individual shape inference helper functions (e.g.,
// TF Graph version, constant values computed, etc.)
class ShapeInference {
public:
ShapeInference(int64_t graph_version, MLIRContext* context);
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
ValuePortInputs* inputs) {
return ::mlir::TF::ComputeInputsRequiredForOutput(
value_port,
[this](const ValuePort& port) {
return results_.find(port) != results_.end();
},
inputs);
}
Attribute ComputeOutputComponent(const ValuePort& value_port) {
return ::mlir::TF::ComputeOutputComponent(
value_port, [this](const ValuePort& port) { return results_[port]; });
}
// Returns ShapeHandle if the op result could be computed as shape.
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
void RecordValue(const ValuePort& value_port, Attribute value) {
results_[value_port] = value;
}
// Performs shape inference on the provided op and return true if the type of
// at least one result has been changed.
// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
// `graph_version` indicates the current GraphDef compatibility versions
// (the versions field in graph.proto).
bool InferShapeForSingleOperation(Operation* op);
// Infers shape on the provided region, including nested ones, iterate until
// fix point with a limit of max_iteration. Returns success if fix point is
// reached before max_iteration.
LogicalResult InferShapeUntilFixPoint(Region* region,
int64_t max_iteration = 10);
// Updates input types and refine shapes inside body of functions that are
// attached to ControlFlow ops (If/While). These functions include Then/Else
// branches of IfOp and Cond/Body functions of WhileOp. These functions share
// following common properties:
// 1) They are never reused, ie. having a single use in module.
// 2) Their input types match those of their parent ops (excluding inputs
// like predicate).
// Returns a boolean indicating whether any change has been applied.
LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
ArrayRef<Type> input_types,
int64_t max_iteration);
// Propagate the shapes to the functions named.
LogicalResult PropagateShapeToFunctions(
ModuleOp module, Operation::operand_type_range input_types,
ArrayRef<StringRef> func_names, int64_t max_iteration);
// Shape propagation for call/control flow ops.
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
int64_t max_iteration);
private:
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
// corresponds to a constant value.
ValuePortResultMap results_;
int64_t graph_version_;
MLIRContext* context_;
Dialect* tf_dialect_;
};
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
: graph_version_(graph_version) {
context_ = context;
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
}
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
InferenceContext* ic) {
LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
auto rt = result.getType().dyn_cast<RankedTensorType>();
if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
int dim_size = rt.getDimSize(0);
// Worklist to direct partial evaluation.
llvm::SmallVector<ValuePort, 4> worklist;
// The ValuePort evaluated results.
// TODO(jpienaar): This could be cached across invocations (e.g., part of some
// inference context).
ValuePortResultMap evaluated;
// Returns whether a ValuePort has been previously computed.
auto has_been_computed = [&evaluated](const ValuePort& port) {
return evaluated.find(port) != evaluated.end();
};
// Returns previously computed ValuePort value.
auto values = [&evaluated](const ValuePort& port) -> Attribute {
return evaluated[port];
};
SmallVector<ValuePort, 4> worklist;
// Simple evaluator that attempts to partially evaluate the input value even
// if unable to evaluate the complete output. Below follows a simple stack
@ -498,7 +563,7 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
SmallVector<ValuePort, 4> inputs;
auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
auto res = ComputeInputsRequiredForOutput(front, &inputs);
if (failed(res)) {
// Abort if unable to find which required inputs need to be computed.
worklist.clear();
@ -513,16 +578,16 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
continue;
}
auto ret = ComputeOutputComponent(front, values);
auto ret = ComputeOutputComponent(front);
if (!ret) continue;
evaluated[front] = ret;
RecordValue(front, ret);
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
// If worklist is empty, then this is the root query op.
if (worklist.empty()) {
LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
if (dea.getNumElements() != 1) {
LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
return {};
@ -536,14 +601,8 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
return ic->MakeShape(dims);
}
// Performs shape inference on the provided op and return true if the type of
// at least one result has been changed.
// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
// `graph_version` indicates the current GraphDef compatibility versions
// (the versions field in graph.proto).
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
int64_t graph_version) {
assert(tf_dialect == op->getDialect());
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
assert(tf_dialect_ == op->getDialect());
// The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough
// to make sure they are preserved in the output.
@ -555,7 +614,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// If no result for this op needs shape inference, we have a fast-path return.
// But if the type is a resource/variant, we do not skip it because we might
// not have the handle shapes.
if (llvm::none_of(op->getResultTypes(), CanBeRefined)) {
if (none_of(op->getResultTypes(), CanBeRefined)) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
<< op->getName() << "'.\n");
return false;
@ -570,8 +629,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
// the end of this function.
if (isa<CastOp>(op) &&
llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
return user->getDialect() != tf_dialect;
all_of(op->getResult(0).getUsers(), [&](Operation* user) {
return user->getDialect() != tf_dialect_;
})) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
"dialect operation users '"
@ -651,7 +710,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// Perform the shape inference using an InferenceContext with the input
// shapes. This object is abstracting the information that the ShapeInference
// function operates on.
InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
input_shapes, input_tensors,
/*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
auto status = c.Run(op_reg_data->shape_inference_fn);
@ -664,7 +723,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// Determine if, during shape computation, the shape functions attempted to
// query an input operand as shape where the input was not known/constant.
bool requires_inputs =
llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
return c.requested_input_tensor_as_partial_shape(input) &&
!input_tensors[input];
});
@ -728,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
new_element_type.isa<TF::VariantType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) {
llvm::SmallVector<mlir::TensorType, 1> subtypes;
SmallVector<TensorType, 1> subtypes;
OpBuilder b(op);
for (const auto& shape_n_type : *handle_shapes_types) {
Type element_type;
@ -748,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
if (result.getType() == new_type) continue;
// Inserts a cast back to the original type if any user is not in the TF
// dialect.
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
result.getType());
// Finally we inferred the shape and replace the type for this result.
result.setType(new_type);
@ -760,29 +819,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
return changed;
}
// Infers shape on the provided region, including nested ones, iterate until fix
// point with a limit of max_iteration. Returns success if fix point is reached
// before max_iteration.
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
int64_t max_iteration = 10);
// Updates input types and refine shapes inside body of functions that are
// attached to ControlFlow ops (If/While). These functions include Then/Else
// branches of IfOp and Cond/Body functions of WhileOp. These functions share
// following common properties:
// 1) They are never reused, ie. having a single use in module.
// 2) Their input types match those of their parent ops (excluding inputs like
// predicate).
// Returns a boolean indicating whether any change has been applied.
LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
llvm::ArrayRef<Type> input_types,
int64_t graph_version,
int64_t max_iteration) {
LogicalResult ShapeInference::RefineShapeForControlFlowFunc(
FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) {
ModuleOp module = func.getParentOfType<ModuleOp>();
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
int num_uses = std::distance(func_uses->begin(), func_uses->end());
if (num_uses != 1) {
func.emitWarning(llvm::formatv(
func.emitWarning(formatv(
"expected control flow function {0} to have exactly 1 use, found {1}.",
func.getName(), num_uses));
return failure();
@ -796,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
}
auto res =
InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration);
auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration);
if (failed(res)) return res;
auto new_return_types = InferShapeForFunctionReturnType(func);
@ -809,20 +851,18 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
return success();
}
LogicalResult PropagateShapeToFunctions(
LogicalResult ShapeInference::PropagateShapeToFunctions(
ModuleOp module, Operation::operand_type_range input_types,
llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
int64_t max_iteration) {
bool success = true;
ArrayRef<StringRef> func_names, int64_t max_iteration) {
bool all_succeeded = true;
auto types = llvm::to_vector<4>(input_types);
for (auto func_name : func_names) {
FuncOp func = module.lookupSymbol<FuncOp>(func_name);
if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
max_iteration))) {
success = false;
}
all_succeeded =
succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) &&
all_succeeded;
}
return mlir::success(success);
return success(all_succeeded);
}
// If the callee has only one use, propagates any constant operand of call_op to
@ -842,7 +882,7 @@ void PropagateConstantToCallee(CallOpInterface call_op,
// the constant inside the function.
for (auto arg : func.getArguments()) {
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) {
if (isa_and_nonnull<TF::ConstOp>(operand)) {
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
}
}
@ -861,33 +901,31 @@ void PropagateConstantFromCallee(CallOpInterface call_op,
for (auto retval :
llvm::enumerate(func.front().getTerminator()->getOperands())) {
auto retval_op = retval.value().getDefiningOp();
if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) {
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
op->getResult(retval.index())
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
}
}
}
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
int64_t graph_version,
int64_t max_iteration) {
LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
Operation* op, int64_t max_iteration) {
ModuleOp module = op->getParentOfType<ModuleOp>();
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
return PropagateShapeToFunctions(
module, llvm::drop_begin(if_op.getOperandTypes(), 1),
{if_op.then_branch(), if_op.else_branch()}, graph_version,
max_iteration);
module, drop_begin(if_op.getOperandTypes(), 1),
{if_op.then_branch(), if_op.else_branch()}, max_iteration);
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
{while_op.cond(), while_op.body()},
graph_version, max_iteration);
max_iteration);
} else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
CallInterfaceCallable callable = call_op.getCallableForCallee();
if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) {
PropagateConstantToCallee(call_op, sym, module);
if (failed(PropagateShapeToFunctions(
module, call_op.getArgOperands().getTypes(),
{sym.getRootReference()}, graph_version, max_iteration))) {
{sym.getRootReference()}, max_iteration))) {
return failure();
}
PropagateConstantFromCallee(call_op, sym, module);
@ -900,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
return success();
}
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
int64_t max_iteration) {
MLIRContext* ctx = region->getContext();
Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
// An operation folder that is used to attempt folding before inference.
OperationFolder folder(ctx);
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
int64_t max_iteration) {
// An operation folder that is used to attempt folding before inference.
OperationFolder folder(context_);
bool changed = true;
// TODO(aminim): we could have a more efficient traversal by guiding the
@ -919,14 +954,14 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
<< "Shape inference, iteration " << iteration << "\n");
region->walk([&](Operation* op) {
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect);
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
// TODO(jpienaar): Debug why we can't just return here. We end up with
// additional constant due to the propagation of constant into attached
// function if we return already.
}
if (op->getDialect() != tf_dialect) {
changed |= InferShapeForNonTFDialectOperation(op, tf_dialect);
if (op->getDialect() != tf_dialect_) {
changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_);
return;
}
@ -935,13 +970,12 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
// Best-effort shape inference in attached functions. Do not return
// failure even if it doesn't get to fixed point.
if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version,
max_iteration))) {
if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
op->emitWarning() << "unable to refine shape of attached function "
"arguments and bodies";
}
changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
changed |= InferShapeForSingleOperation(op);
});
}
@ -956,44 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version) {
ShapeInference context(graph_version, func.getContext());
if (arg_shapes.empty()) {
if (failed(InferShapeUntilFixPoint(&func.getBody(), graph_version)))
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
return failure();
// TODO(b/156276510): Verify that it is always fine to refine a function's
// return type, as long as we do not change the argument shapes.
if (auto return_types = InferShapeForFunctionReturnType(func)) {
func.setType(mlir::FunctionType::get(func.getType().getInputs(),
return_types.getValue(),
func.getContext()));
func.setType(FunctionType::get(func.getType().getInputs(),
return_types.getValue(),
func.getContext()));
}
return success();
}
mlir::FunctionType func_type = func.getType();
FunctionType func_type = func.getType();
bool needs_refinement = false;
llvm::SmallVector<mlir::Type, 4> new_arg_types;
SmallVector<Type, 4> new_arg_types;
new_arg_types.reserve(func_type.getNumInputs());
// Update argument types in-place using the provided arg_shapes.
for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
ArrayRef<int64_t> shape = arg_shapes[i];
mlir::Type element_type;
if (auto input_ty =
func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) {
Type element_type;
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
if (!input_ty || input_ty.getShape().size() != shape.size()) {
return failure();
}
element_type = input_ty.getElementType();
} else {
auto unranked_input_ty =
func_type.getInput(i).dyn_cast<mlir::TensorType>();
auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
if (!unranked_input_ty) {
return failure();
}
element_type = unranked_input_ty.getElementType();
}
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
auto new_arg_type = RankedTensorType::get(shape, element_type);
if (new_arg_type != func_type.getInput(i)) {
// If the new type is more detailed, trigger shape inference.
func.getArgument(i).setType(new_arg_type);
@ -1006,18 +1039,17 @@ LogicalResult InferShapeForFunction(FuncOp func,
return success();
}
mlir::LogicalResult result =
mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody());
if (failed(result)) {
return failure();
}
auto return_types = InferShapeForFunctionReturnType(func);
func.setType(mlir::FunctionType::get(new_arg_types,
return_types.hasValue()
? return_types.getValue()
: func.getType().getResults(),
func.getContext()));
func.setType(FunctionType::get(new_arg_types,
return_types.hasValue()
? return_types.getValue()
: func.getType().getResults(),
func.getContext()));
return success();
}