Refactor MLIR TF shape inference to have a context
This enables reusing the partial results computed and caching the query results (ValuePortResultMap). This also reduces the number of arguments being passed around (otherwise, in the follow-up, I'd need to pass a context everywhere). Should be an NFC change. PiperOrigin-RevId: 311166241 Change-Id: Icb6ea66c6c16a06d4bc9077225f1d7a783548dca
This commit is contained in:
		
							parent
							
								
									88acf9fcc5
								
							
						
					
					
						commit
						a2afd0e358
					
				| @ -66,8 +66,7 @@ using tensorflow::shape_inference::ShapeHandle; | ||||
| namespace mlir { | ||||
| namespace TF { | ||||
| namespace { | ||||
| Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType( | ||||
|     FuncOp func) { | ||||
| Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) { | ||||
|   // Find any return ops.
 | ||||
|   SmallVector<ReturnOp, 4> return_ops; | ||||
|   for (Block& block : func) { | ||||
| @ -137,9 +136,9 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result, | ||||
|       cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result, | ||||
|                                      /*truncate=*/b.getBoolAttr(false)); | ||||
|     } | ||||
|     return mlir::Value(cast_op); | ||||
|     return Value(cast_op); | ||||
|   }; | ||||
|   for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) { | ||||
|   for (OpOperand& use : make_early_inc_range(result.getUses())) { | ||||
|     if (use.getOwner()->getDialect() != tf_dialect && | ||||
|         !IsSupportedNonTFOp(use.getOwner())) | ||||
|       use.set(get_cast_op()); | ||||
| @ -162,7 +161,7 @@ Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) { | ||||
| bool InferShapeForPassThroughOps(OperandRange pass_through_operands, | ||||
|                                  Operation* op, Dialect* tf_dialect) { | ||||
|   bool changed = false; | ||||
|   for (auto entry : llvm::zip(pass_through_operands, op->getResults())) { | ||||
|   for (auto entry : zip(pass_through_operands, op->getResults())) { | ||||
|     Type operand_type = std::get<0>(entry).getType(); | ||||
|     Value result = std::get<1>(entry); | ||||
|     if (result.getType() == operand_type) continue; | ||||
| @ -204,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { | ||||
|         tf_dialect); | ||||
|   } | ||||
|   // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
 | ||||
|   if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) { | ||||
|   if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) { | ||||
|     return InferShapeForPassThroughOps( | ||||
|         tensor_cast.getOperation()->getOperands(), op, tf_dialect); | ||||
|   } | ||||
| @ -254,7 +253,7 @@ GetSubtypes(Type type) { | ||||
| // match the i-th operand type). Returns true if anything is changed.
 | ||||
| bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { | ||||
|   bool changed = false; | ||||
|   for (auto entry : llvm::zip(operands, results)) { | ||||
|   for (auto entry : zip(operands, results)) { | ||||
|     Type operand_type = std::get<0>(entry).getType(); | ||||
|     Type result_type = std::get<1>(entry).getType(); | ||||
|     if (operand_type == result_type) continue; | ||||
| @ -291,14 +290,13 @@ bool InferShapeForCall(Operation* op) { | ||||
|   CallInterfaceCallable callable = call_op.getCallableForCallee(); | ||||
|   SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>(); | ||||
|   if (!sym) return false; | ||||
|   FuncOp func = | ||||
|       dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym)); | ||||
|   FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym)); | ||||
|   if (!func) return false; | ||||
| 
 | ||||
|   bool changed = false; | ||||
|   // Map each of the results of the call to the returned type of the
 | ||||
|   // function.
 | ||||
|   for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) { | ||||
|   for (auto result : zip(op->getResults(), func.getType().getResults())) { | ||||
|     if (std::get<0>(result).getType() == std::get<1>(result)) continue; | ||||
|     // Skip already statically shaped results.
 | ||||
|     if (!CanBeRefined(std::get<0>(result).getType())) continue; | ||||
| @ -335,7 +333,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, | ||||
|   // Map each of the results of the call to the returned type of the
 | ||||
|   // function.
 | ||||
|   bool changed = false; | ||||
|   for (auto result : llvm::zip(op->getResults(), inferred)) { | ||||
|   for (auto result : zip(op->getResults(), inferred)) { | ||||
|     if (std::get<0>(result).getType() == std::get<1>(result)) continue; | ||||
| 
 | ||||
|     // Inserts a cast back to the original type if any user is not in the
 | ||||
| @ -356,7 +354,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, | ||||
| // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
 | ||||
| // scalar value).
 | ||||
| struct ValuePort { | ||||
|   llvm::PointerUnion<Operation*, BlockArgument> producer; | ||||
|   PointerUnion<Operation*, BlockArgument> producer; | ||||
|   SmallVector<unsigned int, 2> port; | ||||
| 
 | ||||
|   bool operator==(const ValuePort& other) const { | ||||
| @ -374,39 +372,38 @@ struct ValuePort { | ||||
|       port = {0}; | ||||
|     } | ||||
|   } | ||||
|   ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer, | ||||
|   ValuePort(PointerUnion<Operation*, BlockArgument> producer, | ||||
|             SmallVector<unsigned int, 2> port) | ||||
|       : producer(producer), port(port) {} | ||||
| 
 | ||||
|   llvm::raw_ostream& print(llvm::raw_ostream& os) const { | ||||
|   raw_ostream& print(raw_ostream& os) const { | ||||
|     if (auto* op = producer.dyn_cast<Operation*>()) | ||||
|       os << "op " << op->getName(); | ||||
|     if (auto ba = producer.dyn_cast<BlockArgument>()) | ||||
|       os << "block_arg " << ba.getArgNumber(); | ||||
|     os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); | ||||
|     os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); | ||||
|     return os; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct ValuePortHasher { | ||||
|   std::size_t operator()(const ValuePort& other) const { | ||||
|     return llvm::hash_combine( | ||||
|         llvm::hash_value(other.producer.getOpaqueValue()), | ||||
|         llvm::hash_value(ArrayRef<unsigned int>(other.port))); | ||||
|     return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()), | ||||
|                         hash_value(ArrayRef<unsigned int>(other.port))); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| using ValuePortResultMap = | ||||
|     std::unordered_map<ValuePort, Attribute, ValuePortHasher>; | ||||
| using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>; | ||||
| using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>; | ||||
| using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>; | ||||
| using ComputedQueryFn = function_ref<bool(ValuePort)>; | ||||
| using ValueQueryFn = function_ref<Attribute(const ValuePort&)>; | ||||
| using ValuePortInputs = SmallVectorImpl<ValuePort>; | ||||
| 
 | ||||
| // TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
 | ||||
| // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
 | ||||
| // intended to be switched to op interfaces once more refined.
 | ||||
| LogicalResult InputsRequiredForOutput(ValuePort value_port, | ||||
|                                       ComputedQueryFn has_been_computed, | ||||
|                                       ValuePortInputs* inputs) { | ||||
| LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, | ||||
|                                              ComputedQueryFn has_been_computed, | ||||
|                                              ValuePortInputs* inputs) { | ||||
|   auto op = value_port.producer.dyn_cast<Operation*>(); | ||||
|   auto& port = value_port.port; | ||||
|   if (!op) return failure(); | ||||
| @ -460,26 +457,94 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, | ||||
|   return nullptr; | ||||
| } | ||||
| 
 | ||||
| ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { | ||||
| // Context used during ShapeInference. This class contains common information
 | ||||
| // that is required by the individual shape inference helper functions (e.g.,
 | ||||
| // TF Graph version, constant values computed, etc.)
 | ||||
| class ShapeInference { | ||||
|  public: | ||||
|   ShapeInference(int64_t graph_version, MLIRContext* context); | ||||
| 
 | ||||
|   LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, | ||||
|                                                ValuePortInputs* inputs) { | ||||
|     return ::mlir::TF::ComputeInputsRequiredForOutput( | ||||
|         value_port, | ||||
|         [this](const ValuePort& port) { | ||||
|           return results_.find(port) != results_.end(); | ||||
|         }, | ||||
|         inputs); | ||||
|   } | ||||
| 
 | ||||
|   Attribute ComputeOutputComponent(const ValuePort& value_port) { | ||||
|     return ::mlir::TF::ComputeOutputComponent( | ||||
|         value_port, [this](const ValuePort& port) { return results_[port]; }); | ||||
|   } | ||||
| 
 | ||||
|   // Returns ShapeHandle if the op result could be computed as shape.
 | ||||
|   ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); | ||||
| 
 | ||||
|   void RecordValue(const ValuePort& value_port, Attribute value) { | ||||
|     results_[value_port] = value; | ||||
|   } | ||||
| 
 | ||||
|   // Performs shape inference on the provided op and return true if the type of
 | ||||
|   // at least one result has been changed.
 | ||||
|   // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
 | ||||
|   // `graph_version` indicates the current GraphDef compatibility versions
 | ||||
|   // (the versions field in graph.proto).
 | ||||
|   bool InferShapeForSingleOperation(Operation* op); | ||||
| 
 | ||||
|   // Infers shape on the provided region, including nested ones, iterate until
 | ||||
|   // fix point with a limit of max_iteration. Returns success if fix point is
 | ||||
|   // reached before max_iteration.
 | ||||
|   LogicalResult InferShapeUntilFixPoint(Region* region, | ||||
|                                         int64_t max_iteration = 10); | ||||
| 
 | ||||
|   // Updates input types and refine shapes inside body of functions that are
 | ||||
|   // attached to ControlFlow ops (If/While). These functions include Then/Else
 | ||||
|   // branches of IfOp and Cond/Body functions of WhileOp. These functions share
 | ||||
|   // following common properties:
 | ||||
|   //   1) They are never reused, ie. having a single use in module.
 | ||||
|   //   2) Their input types match those of their parent ops (excluding inputs
 | ||||
|   //      like predicate).
 | ||||
|   // Returns a boolean indicating whether any change has been applied.
 | ||||
|   LogicalResult RefineShapeForControlFlowFunc(FuncOp func, | ||||
|                                               ArrayRef<Type> input_types, | ||||
|                                               int64_t max_iteration); | ||||
| 
 | ||||
|   // Propagate the shapes to the functions named.
 | ||||
|   LogicalResult PropagateShapeToFunctions( | ||||
|       ModuleOp module, Operation::operand_type_range input_types, | ||||
|       ArrayRef<StringRef> func_names, int64_t max_iteration); | ||||
| 
 | ||||
|   // Shape propagation for call/control flow ops.
 | ||||
|   LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, | ||||
|                                                     int64_t max_iteration); | ||||
| 
 | ||||
|  private: | ||||
|   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
 | ||||
|   // e.g., first element of OpResult produded) to an Attribute if the ValuePort
 | ||||
|   // corresponds to a constant value.
 | ||||
|   ValuePortResultMap results_; | ||||
|   int64_t graph_version_; | ||||
|   MLIRContext* context_; | ||||
|   Dialect* tf_dialect_; | ||||
| }; | ||||
| 
 | ||||
| ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context) | ||||
|     : graph_version_(graph_version) { | ||||
|   context_ = context; | ||||
|   tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>(); | ||||
| } | ||||
| 
 | ||||
| ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, | ||||
|                                                  InferenceContext* ic) { | ||||
|   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); | ||||
|   auto rt = result.getType().dyn_cast<RankedTensorType>(); | ||||
|   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; | ||||
|   int dim_size = rt.getDimSize(0); | ||||
| 
 | ||||
|   // Worklist to direct partial evaluation.
 | ||||
|   llvm::SmallVector<ValuePort, 4> worklist; | ||||
|   // The ValuePort evaluated results.
 | ||||
|   // TODO(jpienaar): This could be cached across invocations (e.g., part of some
 | ||||
|   // inference context).
 | ||||
|   ValuePortResultMap evaluated; | ||||
|   // Returns whether a ValuePort has been previously computed.
 | ||||
|   auto has_been_computed = [&evaluated](const ValuePort& port) { | ||||
|     return evaluated.find(port) != evaluated.end(); | ||||
|   }; | ||||
|   // Returns previously computed ValuePort value.
 | ||||
|   auto values = [&evaluated](const ValuePort& port) -> Attribute { | ||||
|     return evaluated[port]; | ||||
|   }; | ||||
|   SmallVector<ValuePort, 4> worklist; | ||||
| 
 | ||||
|   // Simple evaluator that attempts to partially evaluate the input value even
 | ||||
|   // if unable to evaluate the complete output. Below follows a simple stack
 | ||||
| @ -498,7 +563,7 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { | ||||
|       LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front ")); | ||||
| 
 | ||||
|       SmallVector<ValuePort, 4> inputs; | ||||
|       auto res = InputsRequiredForOutput(front, has_been_computed, &inputs); | ||||
|       auto res = ComputeInputsRequiredForOutput(front, &inputs); | ||||
|       if (failed(res)) { | ||||
|         // Abort if unable to find which required inputs need to be computed.
 | ||||
|         worklist.clear(); | ||||
| @ -513,16 +578,16 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { | ||||
|         continue; | ||||
|       } | ||||
| 
 | ||||
|       auto ret = ComputeOutputComponent(front, values); | ||||
|       auto ret = ComputeOutputComponent(front); | ||||
|       if (!ret) continue; | ||||
| 
 | ||||
|       evaluated[front] = ret; | ||||
|       RecordValue(front, ret); | ||||
|       LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); | ||||
| 
 | ||||
|       // If worklist is empty, then this is the root query op.
 | ||||
|       if (worklist.empty()) { | ||||
|         LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); | ||||
|         if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) { | ||||
|         if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) { | ||||
|           if (dea.getNumElements() != 1) { | ||||
|             LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n"); | ||||
|             return {}; | ||||
| @ -536,14 +601,8 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { | ||||
|   return ic->MakeShape(dims); | ||||
| } | ||||
| 
 | ||||
| // Performs shape inference on the provided op and return true if the type of
 | ||||
| // at least one result has been changed.
 | ||||
| // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
 | ||||
| // `graph_version` indicates the current GraphDef compatibility versions
 | ||||
| // (the versions field in graph.proto).
 | ||||
| bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|                                   int64_t graph_version) { | ||||
|   assert(tf_dialect == op->getDialect()); | ||||
| bool ShapeInference::InferShapeForSingleOperation(Operation* op) { | ||||
|   assert(tf_dialect_ == op->getDialect()); | ||||
|   // The shape function of these ops sometimes does not propagate subtypes
 | ||||
|   // (handle shapes) for resource and variant types. We use a simple passthrough
 | ||||
|   // to make sure they are preserved in the output.
 | ||||
| @ -555,7 +614,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|   // If no result for this op needs shape inference, we have a fast-path return.
 | ||||
|   // But if the type is a resource/variant, we do not skip it because we might
 | ||||
|   // not have the handle shapes.
 | ||||
|   if (llvm::none_of(op->getResultTypes(), CanBeRefined)) { | ||||
|   if (none_of(op->getResultTypes(), CanBeRefined)) { | ||||
|     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" | ||||
|                             << op->getName() << "'.\n"); | ||||
|     return false; | ||||
| @ -570,8 +629,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|   // This is necessary to avoid reprocessing the tf.Cast that are inserted at
 | ||||
|   // the end of this function.
 | ||||
|   if (isa<CastOp>(op) && | ||||
|       llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) { | ||||
|         return user->getDialect() != tf_dialect; | ||||
|       all_of(op->getResult(0).getUsers(), [&](Operation* user) { | ||||
|         return user->getDialect() != tf_dialect_; | ||||
|       })) { | ||||
|     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF " | ||||
|                                "dialect operation users '" | ||||
| @ -651,7 +710,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|   // Perform the shape inference using an InferenceContext with the input
 | ||||
|   // shapes. This object is abstracting the information that the ShapeInference
 | ||||
|   // function operates on.
 | ||||
|   InferenceContext c(graph_version, *node_def, op_reg_data->op_def, | ||||
|   InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, | ||||
|                      input_shapes, input_tensors, | ||||
|                      /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); | ||||
|   auto status = c.Run(op_reg_data->shape_inference_fn); | ||||
| @ -664,7 +723,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|   // Determine if, during shape computation, the shape functions attempted to
 | ||||
|   // query an input operand as shape where the input was not known/constant.
 | ||||
|   bool requires_inputs = | ||||
|       llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) { | ||||
|       any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) { | ||||
|         return c.requested_input_tensor_as_partial_shape(input) && | ||||
|                !input_tensors[input]; | ||||
|       }); | ||||
| @ -728,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|         new_element_type.isa<TF::VariantType>()) { | ||||
|       auto handle_shapes_types = c.output_handle_shapes_and_types(output); | ||||
|       if (handle_shapes_types) { | ||||
|         llvm::SmallVector<mlir::TensorType, 1> subtypes; | ||||
|         SmallVector<TensorType, 1> subtypes; | ||||
|         OpBuilder b(op); | ||||
|         for (const auto& shape_n_type : *handle_shapes_types) { | ||||
|           Type element_type; | ||||
| @ -748,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|     if (result.getType() == new_type) continue; | ||||
|     // Inserts a cast back to the original type if any user is not in the TF
 | ||||
|     // dialect.
 | ||||
|     AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, | ||||
|     AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, | ||||
|                                        result.getType()); | ||||
|     // Finally we inferred the shape and replace the type for this result.
 | ||||
|     result.setType(new_type); | ||||
| @ -760,29 +819,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, | ||||
|   return changed; | ||||
| } | ||||
| 
 | ||||
| // Infers shape on the provided region, including nested ones, iterate until fix
 | ||||
| // point with a limit of max_iteration. Returns success if fix point is reached
 | ||||
| // before max_iteration.
 | ||||
| LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, | ||||
|                                       int64_t max_iteration = 10); | ||||
| 
 | ||||
| // Updates input types and refine shapes inside body of functions that are
 | ||||
| // attached to ControlFlow ops (If/While). These functions include Then/Else
 | ||||
| // branches of IfOp and Cond/Body functions of WhileOp. These functions share
 | ||||
| // following common properties:
 | ||||
| //   1) They are never reused, ie. having a single use in module.
 | ||||
| //   2) Their input types match those of their parent ops (excluding inputs like
 | ||||
| //      predicate).
 | ||||
| // Returns a boolean indicating whether any change has been applied.
 | ||||
| LogicalResult RefineShapeForControlFlowFunc(FuncOp func, | ||||
|                                             llvm::ArrayRef<Type> input_types, | ||||
|                                             int64_t graph_version, | ||||
|                                             int64_t max_iteration) { | ||||
| LogicalResult ShapeInference::RefineShapeForControlFlowFunc( | ||||
|     FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) { | ||||
|   ModuleOp module = func.getParentOfType<ModuleOp>(); | ||||
|   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); | ||||
|   int num_uses = std::distance(func_uses->begin(), func_uses->end()); | ||||
|   if (num_uses != 1) { | ||||
|     func.emitWarning(llvm::formatv( | ||||
|     func.emitWarning(formatv( | ||||
|         "expected control flow function {0} to have exactly 1 use, found {1}.", | ||||
|         func.getName(), num_uses)); | ||||
|     return failure(); | ||||
| @ -796,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, | ||||
|     arg_and_idx.value().setType(input_types[arg_and_idx.index()]); | ||||
|   } | ||||
| 
 | ||||
|   auto res = | ||||
|       InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration); | ||||
|   auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration); | ||||
|   if (failed(res)) return res; | ||||
| 
 | ||||
|   auto new_return_types = InferShapeForFunctionReturnType(func); | ||||
| @ -809,20 +851,18 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult PropagateShapeToFunctions( | ||||
| LogicalResult ShapeInference::PropagateShapeToFunctions( | ||||
|     ModuleOp module, Operation::operand_type_range input_types, | ||||
|     llvm::ArrayRef<StringRef> func_names, int64_t graph_version, | ||||
|     int64_t max_iteration) { | ||||
|   bool success = true; | ||||
|     ArrayRef<StringRef> func_names, int64_t max_iteration) { | ||||
|   bool all_succeeded = true; | ||||
|   auto types = llvm::to_vector<4>(input_types); | ||||
|   for (auto func_name : func_names) { | ||||
|     FuncOp func = module.lookupSymbol<FuncOp>(func_name); | ||||
|     if (failed(RefineShapeForControlFlowFunc(func, types, graph_version, | ||||
|                                              max_iteration))) { | ||||
|       success = false; | ||||
|     } | ||||
|     all_succeeded = | ||||
|         succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) && | ||||
|         all_succeeded; | ||||
|   } | ||||
|   return mlir::success(success); | ||||
|   return success(all_succeeded); | ||||
| } | ||||
| 
 | ||||
| // If the callee has only one use, propagates any constant operand of call_op to
 | ||||
| @ -842,7 +882,7 @@ void PropagateConstantToCallee(CallOpInterface call_op, | ||||
|     // the constant inside the function.
 | ||||
|     for (auto arg : func.getArguments()) { | ||||
|       auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); | ||||
|       if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) { | ||||
|       if (isa_and_nonnull<TF::ConstOp>(operand)) { | ||||
|         arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); | ||||
|       } | ||||
|     } | ||||
| @ -861,33 +901,31 @@ void PropagateConstantFromCallee(CallOpInterface call_op, | ||||
|   for (auto retval : | ||||
|        llvm::enumerate(func.front().getTerminator()->getOperands())) { | ||||
|     auto retval_op = retval.value().getDefiningOp(); | ||||
|     if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) { | ||||
|     if (isa_and_nonnull<TF::ConstOp>(retval_op)) { | ||||
|       op->getResult(retval.index()) | ||||
|           .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, | ||||
|                                                   int64_t graph_version, | ||||
|                                                   int64_t max_iteration) { | ||||
| LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( | ||||
|     Operation* op, int64_t max_iteration) { | ||||
|   ModuleOp module = op->getParentOfType<ModuleOp>(); | ||||
|   if (auto if_op = dyn_cast<TF::IfOp>(op)) { | ||||
|     return PropagateShapeToFunctions( | ||||
|         module, llvm::drop_begin(if_op.getOperandTypes(), 1), | ||||
|         {if_op.then_branch(), if_op.else_branch()}, graph_version, | ||||
|         max_iteration); | ||||
|         module, drop_begin(if_op.getOperandTypes(), 1), | ||||
|         {if_op.then_branch(), if_op.else_branch()}, max_iteration); | ||||
|   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) { | ||||
|     return PropagateShapeToFunctions(module, while_op.getOperandTypes(), | ||||
|                                      {while_op.cond(), while_op.body()}, | ||||
|                                      graph_version, max_iteration); | ||||
|                                      max_iteration); | ||||
|   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) { | ||||
|     CallInterfaceCallable callable = call_op.getCallableForCallee(); | ||||
|     if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) { | ||||
|       PropagateConstantToCallee(call_op, sym, module); | ||||
|       if (failed(PropagateShapeToFunctions( | ||||
|               module, call_op.getArgOperands().getTypes(), | ||||
|               {sym.getRootReference()}, graph_version, max_iteration))) { | ||||
|               {sym.getRootReference()}, max_iteration))) { | ||||
|         return failure(); | ||||
|       } | ||||
|       PropagateConstantFromCallee(call_op, sym, module); | ||||
| @ -900,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, | ||||
|                                       int64_t max_iteration) { | ||||
|   MLIRContext* ctx = region->getContext(); | ||||
|   Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>(); | ||||
| 
 | ||||
|   // An operation folder that is used to attempt folding before inference.
 | ||||
|   OperationFolder folder(ctx); | ||||
| LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, | ||||
|                                                       int64_t max_iteration) { | ||||
|   // An operation folder that is used to attempt folding before inference._
 | ||||
|   OperationFolder folder(context_); | ||||
|   bool changed = true; | ||||
| 
 | ||||
|   // TODO(aminim): we could have a more efficient traversal by guiding the
 | ||||
| @ -919,14 +954,14 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, | ||||
|                << "Shape inference, iteration " << iteration << "\n"); | ||||
|     region->walk([&](Operation* op) { | ||||
|       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) { | ||||
|         changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect); | ||||
|         changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); | ||||
|         // TODO(jpienaar): Debug why we can't just return here. We end up with
 | ||||
|         // additional constant due to the propagation of constant into attached
 | ||||
|         // function if we return already.
 | ||||
|       } | ||||
| 
 | ||||
|       if (op->getDialect() != tf_dialect) { | ||||
|         changed |= InferShapeForNonTFDialectOperation(op, tf_dialect); | ||||
|       if (op->getDialect() != tf_dialect_) { | ||||
|         changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_); | ||||
|         return; | ||||
|       } | ||||
| 
 | ||||
| @ -935,13 +970,12 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, | ||||
| 
 | ||||
|       // Best-effort shape inference in attached functions. Do not return
 | ||||
|       // failure even if it doesn't get to fixed point.
 | ||||
|       if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version, | ||||
|                                                      max_iteration))) { | ||||
|       if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) { | ||||
|         op->emitWarning() << "unable to refine shape of attached function " | ||||
|                              "arguments and bodies"; | ||||
|       } | ||||
| 
 | ||||
|       changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); | ||||
|       changed |= InferShapeForSingleOperation(op); | ||||
|     }); | ||||
|   } | ||||
| 
 | ||||
| @ -956,44 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, | ||||
| LogicalResult InferShapeForFunction(FuncOp func, | ||||
|                                     ArrayRef<ArrayRef<int64_t>> arg_shapes, | ||||
|                                     int64_t graph_version) { | ||||
|   ShapeInference context(graph_version, func.getContext()); | ||||
|   if (arg_shapes.empty()) { | ||||
|     if (failed(InferShapeUntilFixPoint(&func.getBody(), graph_version))) | ||||
|     if (failed(context.InferShapeUntilFixPoint(&func.getBody()))) | ||||
|       return failure(); | ||||
|     // TODO(b/156276510): Verify that it is always fine to refine a function's
 | ||||
|     // return type, as long as we do not change the argument shapes.
 | ||||
|     if (auto return_types = InferShapeForFunctionReturnType(func)) { | ||||
|       func.setType(mlir::FunctionType::get(func.getType().getInputs(), | ||||
|                                            return_types.getValue(), | ||||
|                                            func.getContext())); | ||||
|       func.setType(FunctionType::get(func.getType().getInputs(), | ||||
|                                      return_types.getValue(), | ||||
|                                      func.getContext())); | ||||
|     } | ||||
| 
 | ||||
|     return success(); | ||||
|   } | ||||
|   mlir::FunctionType func_type = func.getType(); | ||||
|   FunctionType func_type = func.getType(); | ||||
|   bool needs_refinement = false; | ||||
|   llvm::SmallVector<mlir::Type, 4> new_arg_types; | ||||
|   SmallVector<Type, 4> new_arg_types; | ||||
|   new_arg_types.reserve(func_type.getNumInputs()); | ||||
| 
 | ||||
|   // Update argument types in-place using the provided arg_shapes.
 | ||||
|   for (size_t i = 0; i < func_type.getNumInputs(); ++i) { | ||||
|     ArrayRef<int64_t> shape = arg_shapes[i]; | ||||
|     mlir::Type element_type; | ||||
|     if (auto input_ty = | ||||
|             func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) { | ||||
|     Type element_type; | ||||
|     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) { | ||||
|       if (!input_ty || input_ty.getShape().size() != shape.size()) { | ||||
|         return failure(); | ||||
|       } | ||||
|       element_type = input_ty.getElementType(); | ||||
|     } else { | ||||
|       auto unranked_input_ty = | ||||
|           func_type.getInput(i).dyn_cast<mlir::TensorType>(); | ||||
|       auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>(); | ||||
|       if (!unranked_input_ty) { | ||||
|         return failure(); | ||||
|       } | ||||
|       element_type = unranked_input_ty.getElementType(); | ||||
|     } | ||||
| 
 | ||||
|     auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); | ||||
|     auto new_arg_type = RankedTensorType::get(shape, element_type); | ||||
|     if (new_arg_type != func_type.getInput(i)) { | ||||
|       // If the new type is more detailed, trigger shape inference.
 | ||||
|       func.getArgument(i).setType(new_arg_type); | ||||
| @ -1006,18 +1039,17 @@ LogicalResult InferShapeForFunction(FuncOp func, | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   mlir::LogicalResult result = | ||||
|       mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version); | ||||
|   LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody()); | ||||
|   if (failed(result)) { | ||||
|     return failure(); | ||||
|   } | ||||
| 
 | ||||
|   auto return_types = InferShapeForFunctionReturnType(func); | ||||
|   func.setType(mlir::FunctionType::get(new_arg_types, | ||||
|                                        return_types.hasValue() | ||||
|                                            ? return_types.getValue() | ||||
|                                            : func.getType().getResults(), | ||||
|                                        func.getContext())); | ||||
|   func.setType(FunctionType::get(new_arg_types, | ||||
|                                  return_types.hasValue() | ||||
|                                      ? return_types.getValue() | ||||
|                                      : func.getType().getResults(), | ||||
|                                  func.getContext())); | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user