diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
index 8a187cf43a8..92571148cff 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
@@ -29,11 +29,25 @@ limitations under the License.
 #include "tfrt/tensor/dense_host_tensor_view.h"
 
 namespace tensorflow {
+namespace {
 
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+// Flattens a `tf_saved_model.index_path` attribute into a single name.
+llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
+  if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
+    // TODO(chky): Support cases where index_path is not a single string.
+    return index_path[0].cast<mlir::StringAttr>().getValue();
+  }
+  return "";
+}
+
+}  // namespace
+
+void MapFunctionSignaturesFromTFSavedModelMLIR(
     mlir::ModuleOp module,
-    llvm::function_ref<void(llvm::StringRef func_name,
-                            llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp>
-                                global_tensors)>
+    llvm::function_ref<void(
+        llvm::StringRef func_name,
+        llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+            input_names_and_devices,
+        llvm::ArrayRef<llvm::StringRef> output_names,
+        llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
         map_fn) {
   // Create global_tensors for each functions.
@@ -44,17 +58,38 @@ void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
     auto func_names = mlir::tf_saved_model::GetExportedNames(func);
     if (func_names.empty()) return;
 
-    // Here we walk through each arguments and find out the variables used by
-    // this function.
+    // Here we walk through each arguments and find out the input/output names,
+    // and input devices, variables used by this function.
+    llvm::SmallVector<std::pair<llvm::StringRef, llvm::StringRef>, 4>
+        input_names_and_devices;
     llvm::SmallVector<mlir::tf_saved_model::GlobalTensorOp, 4> global_tensors;
     for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
+      if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
+              i, "tf_saved_model.index_path")) {
+        std::pair<llvm::StringRef, llvm::StringRef> name_and_device;
+        name_and_device.first = ProcessIndexPath(input_index_path);
+        if (auto input_device =
+                func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
+          name_and_device.second = input_device.getValue();
+        }
+        input_names_and_devices.push_back(name_and_device);
+      }
       if (auto variable =
              mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
        global_tensors.push_back(variable);
      }
    }
 
-    for (auto func_name : func_names) map_fn(func_name, global_tensors);
+    llvm::SmallVector<llvm::StringRef, 4> output_names;
+    for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
+      if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
+              i, "tf_saved_model.index_path")) {
+        output_names.push_back(ProcessIndexPath(output_index_path));
+      }
+    }
+
+    for (auto func_name : func_names)
+      map_fn(func_name, input_names_and_devices, output_names, global_tensors);
   });
 }
 
diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
index de24ea20958..06a6c5a22f9 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
@@ -57,12 +57,15 @@ struct TFRTSavedModelCompileOptions {
   std::string force_data_format;
 };
 
-// Map captured global tensors for each function.
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+// Map signatures (eg. input/output names, variables) for each function.
+void MapFunctionSignaturesFromTFSavedModelMLIR(
     mlir::ModuleOp module,
-    llvm::function_ref<
-        void(llvm::StringRef func_name,
-             llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> captures)>
+    llvm::function_ref<void(
+        llvm::StringRef func_name,
+        llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+            input_names_and_devices,
+        llvm::ArrayRef<llvm::StringRef> output_names,
+        llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
     map_fn);
 
 // Compile MLIR in TF saved model dialect into BEF.