Use _input_shapes FunctionDef attr when construction mlir::Func signature.
PiperOrigin-RevId: 267190906
This commit is contained in:
parent
3706ccd34a
commit
fdfc30807b
@ -0,0 +1,110 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
|
||||
|
||||
# Verify that the _input_shapes attribute of the FunctionDef is respected.
|
||||
# This also checks that the output type is correctly inferred based on
|
||||
# that.
|
||||
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
|
||||
|
||||
node {
|
||||
name: "Placeholder"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder_1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "If"
|
||||
op: "If"
|
||||
input: "Placeholder"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "Tcond"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "else_branch"
|
||||
value {
|
||||
func {
|
||||
name: "identity_function"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "then_branch"
|
||||
value {
|
||||
func {
|
||||
name: "identity_function"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "identity_function"
|
||||
input_arg {
|
||||
name: "identity_input"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "identity_output"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "identity_output"
|
||||
value: "identity_input"
|
||||
}
|
||||
attr {
|
||||
key: "_input_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 29
|
||||
min_consumer: 12
|
||||
}
|
||||
|
@ -476,7 +476,8 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
auto it = specs_.inputs.find(node->name());
|
||||
if (it != specs_.inputs.end()) {
|
||||
auto node_name = node->op_def().name();
|
||||
if (node_name != "Placeholder" && node_name != "LegacyFedInput") {
|
||||
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
|
||||
node_name != "_Arg") {
|
||||
// We do not handle the case where the input node has multple outputs
|
||||
if (node->num_outputs() > 1) {
|
||||
return errors::FailedPrecondition(absl::StrCat(
|
||||
@ -845,10 +846,28 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
|
||||
attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
|
||||
}
|
||||
|
||||
// Converts the graph to a MLIR function and adds it to the module. Uses the
|
||||
// default node spec without any inputs or outputs as the function graph has
|
||||
// special '_Arg' and '_Retval' ops for argument and return values.
|
||||
// Converts the graph to a MLIR function and adds it to the module.
|
||||
// We populate the NodeSpec so that all the _Arg ops get their shape
|
||||
// added correctly.
|
||||
NodeSpecs specs;
|
||||
for (const auto& name_and_value : func_def->attr()) {
|
||||
if (name_and_value.first == "_input_shapes") {
|
||||
auto& list = name_and_value.second.list();
|
||||
auto& signature = func_def->signature();
|
||||
for (int i = 0; i < list.shape_size(); i++) {
|
||||
auto& input_arg = signature.input_arg(i);
|
||||
auto& array_info = specs.inputs[input_arg.name()];
|
||||
array_info.imported_dtype = input_arg.type();
|
||||
array_info.shape = list.shape(i);
|
||||
// TODO(b/140464702): These fields should not be exposed here.
|
||||
// Seems like a layering violation. Initialize them anyway.
|
||||
array_info.final_dtype = input_arg.type();
|
||||
array_info.min_value = 0.0;
|
||||
array_info.max_value = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
|
||||
tf_name_to_mlir_name_);
|
||||
TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
|
||||
@ -1399,16 +1418,22 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
|
||||
const FunctionBody& fbody) {
|
||||
mlir::Builder builder(context_);
|
||||
|
||||
// The FunctionBody contains a graph with a single-output _Arg node for each
|
||||
// function argument and a single-input _Retval node for each function return
|
||||
// value.
|
||||
//
|
||||
// We already populated the ShapeRefiner with all the information about the
|
||||
// shapes of these graph edges, so we just query it to build the corresponding
|
||||
// MLIR function type signature.
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
arg_types.reserve(fbody.arg_types.size());
|
||||
for (auto dataType : fbody.arg_types) {
|
||||
mlir::Type element_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
::tensorflow::ConvertDataType(dataType, builder, &element_type));
|
||||
// TODO(hinsu): Derive shape of function arguments based on shapes available
|
||||
// at call sites of this function. That way it is possible to have a
|
||||
// partially known shape in some cases instead of unranked tensor types.
|
||||
arg_types.push_back(builder.getTensorType(element_type));
|
||||
for (auto arg : fbody.arg_nodes) {
|
||||
// Find node in the graph using the node id instead of using `arg` directly
|
||||
// because the graph has been cloned.
|
||||
auto* node = graph_->FindNodeId(arg->id());
|
||||
TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder));
|
||||
arg_types.push_back(type);
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
@ -1417,9 +1442,6 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
|
||||
// Find node in the graph using the node id instead of using `ret` directly
|
||||
// because the graph has been cloned.
|
||||
auto* node = graph_->FindNodeId(ret->id());
|
||||
|
||||
// Return type of the function is type of the only input of the respective
|
||||
// return node in the function.
|
||||
TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
|
||||
ret_types.push_back(type);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user