Use the _input_shapes FunctionDef attr when constructing the mlir::Func signature.
PiperOrigin-RevId: 267190906
This commit is contained in:
parent 3706ccd34a
commit fdfc30807b

tensorflow/compiler/mlir/tensorflow
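The change below seeds the function importer's input specs from the FunctionDef's _input_shapes attribute, pairing each recorded shape with the dtype of the corresponding input_arg in the signature, so the function's _Arg nodes carry concrete shapes before shape inference runs. A minimal sketch of that lookup, assuming a populated tensorflow::FunctionDef; the InputArrayInfo struct and BuildInputArrayInfos helper are illustrative stand-ins for the importer's internal types, not part of this commit:

#include <map>
#include <string>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"

// Illustrative stand-in for the importer's per-input ArrayInfo.
struct InputArrayInfo {
  tensorflow::DataType dtype;
  tensorflow::TensorShapeProto shape;
};

// Pairs each _input_shapes entry with the matching input_arg of the
// FunctionDef signature, mirroring the ConvertLibFunction change below.
std::map<std::string, InputArrayInfo> BuildInputArrayInfos(
    const tensorflow::FunctionDef& func_def) {
  std::map<std::string, InputArrayInfo> inputs;
  auto it = func_def.attr().find("_input_shapes");
  if (it == func_def.attr().end()) return inputs;

  const auto& shapes = it->second.list();
  const auto& signature = func_def.signature();
  // The i-th shape in the list corresponds to the i-th input argument.
  for (int i = 0; i < shapes.shape_size() && i < signature.input_arg_size();
       ++i) {
    const auto& input_arg = signature.input_arg(i);
    InputArrayInfo& info = inputs[input_arg.name()];
    info.dtype = input_arg.type();
    info.shape = shapes.shape(i);
  }
  return inputs;
}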
@@ -0,0 +1,110 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
+
+# Verify that the _input_shapes attribute of the FunctionDef is respected.
+# This also checks that the output type is correctly inferred based on
+# that.
+#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
+
+node {
+  name: "Placeholder"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_BOOL
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Placeholder_1"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "If"
+  op: "If"
+  input: "Placeholder"
+  input: "Placeholder_1"
+  attr {
+    key: "Tcond"
+    value {
+      type: DT_BOOL
+    }
+  }
+  attr {
+    key: "Tin"
+    value {
+      list {
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    key: "Tout"
+    value {
+      list {
+        type: DT_INT32
+      }
+    }
+  }
+  attr {
+    key: "else_branch"
+    value {
+      func {
+        name: "identity_function"
+      }
+    }
+  }
+  attr {
+    key: "then_branch"
+    value {
+      func {
+        name: "identity_function"
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+library {
+  function {
+    signature {
+      name: "identity_function"
+      input_arg {
+        name: "identity_input"
+        type: DT_INT32
+      }
+      output_arg {
+        name: "identity_output"
+        type: DT_INT32
+      }
+    }
+    ret {
+      key: "identity_output"
+      value: "identity_input"
+    }
+    attr {
+      key: "_input_shapes"
+      value {
+        list {
+          shape {
+          }
+        }
+      }
+    }
+  }
+}
+versions {
+  producer: 29
+  min_consumer: 12
+}
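For this test, the single empty shape { } entry in _input_shapes is a rank-0 (scalar) shape, and the DT_INT32 dtype from the signature is what makes the imported signature use tensor<i32>, as the CHECK line above expects. The importer actually obtains these types through the shape refiner (see the InferLibFunctionType hunk below); the following is only a conceptual sketch of the dtype-plus-shape to MLIR tensor type mapping, reusing the ConvertDataType helper visible in the diff. Header paths and the ShapeAndDtypeToTensorType name are approximations for the MLIR/TensorFlow revision in question, not part of this commit:

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"  // RankedTensorType / UnrankedTensorType at this revision
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Builds the MLIR tensor type for one function argument from the dtype in the
// FunctionDef signature and the corresponding _input_shapes entry. An empty
// TensorShapeProto (rank 0, known rank) yields a scalar type such as
// tensor<i32>; a shape with dims {2, 3} would yield tensor<2x3xi32>.
tensorflow::Status ShapeAndDtypeToTensorType(
    tensorflow::DataType dtype, const tensorflow::TensorShapeProto& shape,
    mlir::Builder builder, mlir::Type* result) {
  mlir::Type element_type;
  TF_RETURN_IF_ERROR(
      ::tensorflow::ConvertDataType(dtype, builder, &element_type));
  if (shape.unknown_rank()) {
    // No rank information at all: fall back to an unranked tensor type.
    *result = mlir::UnrankedTensorType::get(element_type);
    return tensorflow::Status::OK();
  }
  llvm::SmallVector<int64_t, 4> dims;
  dims.reserve(shape.dim_size());
  for (const auto& dim : shape.dim()) dims.push_back(dim.size());
  *result = mlir::RankedTensorType::get(dims, element_type);
  return tensorflow::Status::OK();
}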
@@ -476,7 +476,8 @@ Status ImporterBase::AddNodesToShapeRefiner() {
     auto it = specs_.inputs.find(node->name());
     if (it != specs_.inputs.end()) {
       auto node_name = node->op_def().name();
-      if (node_name != "Placeholder" && node_name != "LegacyFedInput") {
+      if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
+          node_name != "_Arg") {
         // We do not handle the case where the input node has multple outputs
         if (node->num_outputs() > 1) {
           return errors::FailedPrecondition(absl::StrCat(
@@ -845,10 +846,28 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
     attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
   }
 
-  // Converts the graph to a MLIR function and adds it to the module. Uses the
-  // default node spec without any inputs or outputs as the function graph has
-  // special '_Arg' and '_Retval' ops for argument and return values.
+  // Converts the graph to a MLIR function and adds it to the module.
+  // We populate the NodeSpec so that all the _Arg ops get their shape
+  // added correctly.
   NodeSpecs specs;
+  for (const auto& name_and_value : func_def->attr()) {
+    if (name_and_value.first == "_input_shapes") {
+      auto& list = name_and_value.second.list();
+      auto& signature = func_def->signature();
+      for (int i = 0; i < list.shape_size(); i++) {
+        auto& input_arg = signature.input_arg(i);
+        auto& array_info = specs.inputs[input_arg.name()];
+        array_info.imported_dtype = input_arg.type();
+        array_info.shape = list.shape(i);
+        // TODO(b/140464702): These fields should not be exposed here.
+        // Seems like a layering violation. Initialize them anyway.
+        array_info.final_dtype = input_arg.type();
+        array_info.min_value = 0.0;
+        array_info.max_value = 0.0;
+      }
+    }
+  }
+
   ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
                               tf_name_to_mlir_name_);
   TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
@@ -1399,16 +1418,22 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
     const FunctionBody& fbody) {
   mlir::Builder builder(context_);
 
+  // The FunctionBody contains a graph with a single-output _Arg node for each
+  // function argument and a single-input _Retval node for each function return
+  // value.
+  //
+  // We already populated the ShapeRefiner with all the information about the
+  // shapes of these graph edges, so we just query it to build the corresponding
+  // MLIR function type signature.
+
   llvm::SmallVector<mlir::Type, 4> arg_types;
   arg_types.reserve(fbody.arg_types.size());
-  for (auto dataType : fbody.arg_types) {
-    mlir::Type element_type;
-    TF_RETURN_IF_ERROR(
-        ::tensorflow::ConvertDataType(dataType, builder, &element_type));
-    // TODO(hinsu): Derive shape of function arguments based on shapes available
-    // at call sites of this function. That way it is possible to have a
-    // partially known shape in some cases instead of unranked tensor types.
-    arg_types.push_back(builder.getTensorType(element_type));
+  for (auto arg : fbody.arg_nodes) {
+    // Find node in the graph using the node id instead of using `arg` directly
+    // because the graph has been cloned.
+    auto* node = graph_->FindNodeId(arg->id());
+    TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder));
+    arg_types.push_back(type);
   }
 
   llvm::SmallVector<mlir::Type, 4> ret_types;
@@ -1417,9 +1442,6 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
     // Find node in the graph using the node id instead of using `ret` directly
     // because the graph has been cloned.
     auto* node = graph_->FindNodeId(ret->id());
-
-    // Return type of the function is type of the only input of the respective
-    // return node in the function.
     TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
     ret_types.push_back(type);
   }