Internal change
PiperOrigin-RevId: 311003458 Change-Id: I1a5923edadf3a0101a43dd6dd605c37402b017e4
This commit is contained in:
parent
7c09d15f9f
commit
ce43a59c72
@ -29,11 +29,25 @@ limitations under the License.
|
|||||||
#include "tfrt/tensor/dense_host_tensor_view.h"
|
#include "tfrt/tensor/dense_host_tensor_view.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
|
llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
|
||||||
|
if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
|
||||||
|
// TODO(chky): Support cases where index_path is not a single string.
|
||||||
|
return index_path[0].cast<mlir::StringAttr>().getValue();
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void MapFunctionSignaturesFromTFSavedModelMLIR(
|
||||||
mlir::ModuleOp module,
|
mlir::ModuleOp module,
|
||||||
llvm::function_ref<void(
|
llvm::function_ref<void(
|
||||||
llvm::StringRef func_name,
|
llvm::StringRef func_name,
|
||||||
|
llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
|
||||||
|
input_names_and_devices,
|
||||||
|
llvm::ArrayRef<llvm::StringRef> output_names,
|
||||||
llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
|
llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
|
||||||
map_fn) {
|
map_fn) {
|
||||||
// Create global_tensors for each functions.
|
// Create global_tensors for each functions.
|
||||||
@ -44,17 +58,38 @@ void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
|
|||||||
auto func_names = mlir::tf_saved_model::GetExportedNames(func);
|
auto func_names = mlir::tf_saved_model::GetExportedNames(func);
|
||||||
if (func_names.empty()) return;
|
if (func_names.empty()) return;
|
||||||
|
|
||||||
// Here we walk through each arguments and find out the variables used by
|
// Here we walk through each arguments and find out the input/output names,
|
||||||
// this function.
|
// and input devices, variables used by this function.
|
||||||
|
llvm::SmallVector<std::pair<llvm::StringRef, llvm::StringRef>, 4>
|
||||||
|
input_names_and_devices;
|
||||||
llvm::SmallVector<mlir::tf_saved_model::GlobalTensorOp, 4> global_tensors;
|
llvm::SmallVector<mlir::tf_saved_model::GlobalTensorOp, 4> global_tensors;
|
||||||
for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
|
for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
|
||||||
|
if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
|
||||||
|
i, "tf_saved_model.index_path")) {
|
||||||
|
std::pair<llvm::StringRef, llvm::StringRef> name_and_device;
|
||||||
|
name_and_device.first = ProcessIndexPath(input_index_path);
|
||||||
|
if (auto input_device =
|
||||||
|
func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
|
||||||
|
name_and_device.second = input_device.getValue();
|
||||||
|
}
|
||||||
|
input_names_and_devices.push_back(name_and_device);
|
||||||
|
}
|
||||||
if (auto variable =
|
if (auto variable =
|
||||||
mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
|
mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
|
||||||
global_tensors.push_back(variable);
|
global_tensors.push_back(variable);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto func_name : func_names) map_fn(func_name, global_tensors);
|
llvm::SmallVector<llvm::StringRef, 4> output_names;
|
||||||
|
for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
|
||||||
|
if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
|
||||||
|
i, "tf_saved_model.index_path")) {
|
||||||
|
output_names.push_back(ProcessIndexPath(output_index_path));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto func_name : func_names)
|
||||||
|
map_fn(func_name, input_names_and_devices, output_names, global_tensors);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,12 +57,15 @@ struct TFRTSavedModelCompileOptions {
|
|||||||
std::string force_data_format;
|
std::string force_data_format;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Map captured global tensors for each function.
|
// Map signatures (eg. input/output names, variables) for each function.
|
||||||
void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
|
void MapFunctionSignaturesFromTFSavedModelMLIR(
|
||||||
mlir::ModuleOp module,
|
mlir::ModuleOp module,
|
||||||
llvm::function_ref<
|
llvm::function_ref<void(
|
||||||
void(llvm::StringRef func_name,
|
llvm::StringRef func_name,
|
||||||
llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> captures)>
|
llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
|
||||||
|
input_names_and_devices,
|
||||||
|
llvm::ArrayRef<llvm::StringRef> output_names,
|
||||||
|
llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
|
||||||
map_fn);
|
map_fn);
|
||||||
|
|
||||||
// Compile MLIR in TF saved model dialect into BEF.
|
// Compile MLIR in TF saved model dialect into BEF.
|
||||||
|
Loading…
Reference in New Issue
Block a user