Use _input_shapes FunctionDef attr when construction mlir::Func signature.

PiperOrigin-RevId: 267190906
This commit is contained in:
Sean Silva 2019-09-04 11:11:22 -07:00 committed by TensorFlower Gardener
parent 3706ccd34a
commit fdfc30807b
2 changed files with 147 additions and 15 deletions

View File

@ -0,0 +1,110 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred based on
# that.
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
experimental_debug_info {
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
experimental_debug_info {
}
}
node {
name: "If"
op: "If"
input: "Placeholder"
input: "Placeholder_1"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "else_branch"
value {
func {
name: "identity_function"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "identity_function"
}
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "identity_function"
input_arg {
name: "identity_input"
type: DT_INT32
}
output_arg {
name: "identity_output"
type: DT_INT32
}
}
ret {
key: "identity_output"
value: "identity_input"
}
attr {
key: "_input_shapes"
value {
list {
shape {
}
}
}
}
}
}
versions {
producer: 29
min_consumer: 12
}

View File

@ -476,7 +476,8 @@ Status ImporterBase::AddNodesToShapeRefiner() {
auto it = specs_.inputs.find(node->name());
if (it != specs_.inputs.end()) {
auto node_name = node->op_def().name();
if (node_name != "Placeholder" && node_name != "LegacyFedInput") {
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
node_name != "_Arg") {
// We do not handle the case where the input node has multple outputs
if (node->num_outputs() > 1) {
return errors::FailedPrecondition(absl::StrCat(
@ -845,10 +846,28 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
}
// Converts the graph to a MLIR function and adds it to the module. Uses the
// default node spec without any inputs or outputs as the function graph has
// special '_Arg' and '_Retval' ops for argument and return values.
// Converts the graph to a MLIR function and adds it to the module.
// We populate the NodeSpec so that all the _Arg ops get their shape
// added correctly.
NodeSpecs specs;
for (const auto& name_and_value : func_def->attr()) {
if (name_and_value.first == "_input_shapes") {
auto& list = name_and_value.second.list();
auto& signature = func_def->signature();
for (int i = 0; i < list.shape_size(); i++) {
auto& input_arg = signature.input_arg(i);
auto& array_info = specs.inputs[input_arg.name()];
array_info.imported_dtype = input_arg.type();
array_info.shape = list.shape(i);
// TODO(b/140464702): These fields should not be exposed here.
// Seems like a layering violation. Initialize them anyway.
array_info.final_dtype = input_arg.type();
array_info.min_value = 0.0;
array_info.max_value = 0.0;
}
}
}
ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
tf_name_to_mlir_name_);
TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
@ -1399,16 +1418,22 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
const FunctionBody& fbody) {
mlir::Builder builder(context_);
// The FunctionBody contains a graph with a single-output _Arg node for each
// function argument and a single-input _Retval node for each function return
// value.
//
// We already populated the ShapeRefiner with all the information about the
// shapes of these graph edges, so we just query it to build the corresponding
// MLIR function type signature.
llvm::SmallVector<mlir::Type, 4> arg_types;
arg_types.reserve(fbody.arg_types.size());
for (auto dataType : fbody.arg_types) {
mlir::Type element_type;
TF_RETURN_IF_ERROR(
::tensorflow::ConvertDataType(dataType, builder, &element_type));
// TODO(hinsu): Derive shape of function arguments based on shapes available
// at call sites of this function. That way it is possible to have a
// partially known shape in some cases instead of unranked tensor types.
arg_types.push_back(builder.getTensorType(element_type));
for (auto arg : fbody.arg_nodes) {
// Find node in the graph using the node id instead of using `arg` directly
// because the graph has been cloned.
auto* node = graph_->FindNodeId(arg->id());
TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder));
arg_types.push_back(type);
}
llvm::SmallVector<mlir::Type, 4> ret_types;
@ -1417,9 +1442,6 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
// Find node in the graph using the node id instead of using `ret` directly
// because the graph has been cloned.
auto* node = graph_->FindNodeId(ret->id());
// Return type of the function is type of the only input of the respective
// return node in the function.
TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder));
ret_types.push_back(type);
}