Import shapes from attributes for tf.IteratorGetNext, tf.IteratorGetNextSync, tf.MultiDeviceIteratorGetNextFromShard, tf.InfeedDequeueTuple, and tf.InfeedDequeue when shape inference on import is not enabled, and disable shape inference on import in MlirFunctionOptimizationPass.

Shape inference on import is deprecated and these ops have trivial shapes to import from their attributes, which due to them being derived, are lost post import.

PiperOrigin-RevId: 345698723
Change-Id: Iade6693fb692652bf170245a33d17c1ad2f4fadd
This commit is contained in:
Andy Ly 2020-12-04 09:56:02 -08:00 committed by TensorFlower Gardener
parent ee82f96f12
commit 2022bf35a4
3 changed files with 365 additions and 0 deletions

View File

@ -152,6 +152,12 @@ Status MlirFunctionOptimizationPass::Run(
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
import_config.upgrade_legacy = true;
// Disable shape inference during import as some TensorFlow op fails during
// shape inference with dynamic shaped operands. This in turn causes the
// import to fail. Shape inference during import is going to be removed and
// the shape inference pass is run early in the pass pipeline, shape inference
// during import is not necessary.
import_config.enable_shape_inference = false;
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context);

View File

@ -0,0 +1,334 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -tf-graph-as-function %s -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
node {
name: "args_0"
op: "_Arg"
attr {
key: "T"
value {
type: DT_RESOURCE
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "args_1"
op: "_Arg"
attr {
key: "T"
value {
type: DT_RESOURCE
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "args_2"
op: "_Arg"
attr {
key: "T"
value {
type: DT_RESOURCE
}
}
attr {
key: "index"
value {
i: 2
}
}
}
node {
name: "args_3"
op: "_Arg"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "index"
value {
i: 3
}
}
}
node {
name: "args_4"
op: "_Arg"
attr {
key: "T"
value {
type: DT_INT64
}
}
attr {
key: "index"
value {
i: 4
}
}
}
node {
name: "IteratorGetNext"
op: "IteratorGetNext"
input: "args_0"
attr {
key: "output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 8
}
}
shape {
dim {
size: 2
}
dim {
size: -1
}
dim {
size: 16
}
}
shape {
unknown_rank: true
}
}
}
}
attr {
key: "output_types"
value {
list {
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
node {
name: "IteratorGetNextSync"
op: "IteratorGetNextSync"
input: "args_0"
attr {
key: "output_shapes"
value {
list {
shape {
unknown_rank: true
}
shape {
dim {
size: 3
}
dim {
size: 24
}
}
shape {
dim {
size: -1
}
dim {
size: 4
}
dim {
size: 32
}
}
}
}
}
attr {
key: "output_types"
value {
list {
type: DT_INT16
type: DT_INT32
type: DT_INT64
}
}
}
}
node {
name: "MultiDeviceIteratorGetNextFromShard"
op: "MultiDeviceIteratorGetNextFromShard"
input: "args_2"
input: "args_3"
input: "args_4"
attr {
key: "output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 40
}
}
shape {
unknown_rank: true
}
shape {
dim {
size: 6
}
dim {
size: 48
}
dim {
size: -1
}
}
}
}
}
attr {
key: "output_types"
value {
list {
type: DT_HALF
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}
node {
name: "InfeedDequeueTuple"
op: "InfeedDequeueTuple"
attr {
key: "shapes"
value {
list {
shape {
dim {
size: -1
}
dim {
size: -1
}
dim {
size: -1
}
}
shape {
unknown_rank: true
}
shape {
dim {
size: 7
}
dim {
size: 56
}
}
}
}
}
attr {
key: "dtypes"
value {
list {
type: DT_UINT16
type: DT_UINT32
type: DT_UINT64
}
}
}
}
node {
name: "InfeedDequeue_0"
op: "InfeedDequeue"
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 8
}
dim {
size: -1
}
}
}
}
attr {
key: "dtype"
value {
type: DT_INT8
}
}
}
node {
name: "InfeedDequeue_1"
op: "InfeedDequeue"
attr {
key: "shape"
value {
shape {
dim {
size: 8
}
dim {
size: 64
}
}
}
}
attr {
key: "dtype"
value {
type: DT_UINT8
}
}
}
node {
name: "InfeedDequeue_2"
op: "InfeedDequeue"
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
}
# CHECK-DAG: tf.IteratorGetNext{{.+}}-> (tensor<1x8xbf16>, tensor<2x?x16xf32>, tensor<*xf64>)
# CHECK-DAG: tf.IteratorGetNextSync{{.+}}-> (tensor<*xi16>, tensor<3x24xi32>, tensor<?x4x32xi64>)
# CHECK-DAG: tf.MultiDeviceIteratorGetNextFromShard{{.+}}-> (tensor<5x40xf16>, tensor<*xcomplex<f32>>, tensor<6x48x?xcomplex<f64>>)
# CHECK-DAG: tf.InfeedDequeueTuple{{.+}}-> (tensor<?x?x?xui16>, tensor<*xui32>, tensor<7x56xui64>)
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<?x8x?xi8>
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<8x64xui8>
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<*xi1>

View File

@ -980,6 +980,31 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
}
}
auto type_from_array_attr = [&node, &idx, &builder](
absl::string_view output_shape_attr,
absl::string_view element_type_attr) {
auto* output_shapes = node.attrs().Find(output_shape_attr);
auto* element_types = node.attrs().Find(element_type_attr);
const auto& output_shape = output_shapes->list().shape(idx);
const auto& element_type = element_types->list().type(idx);
return ConvertToMlirTensorType(output_shape, element_type, &builder);
};
if (node.type_string() == "IteratorGetNext" ||
node.type_string() == "IteratorGetNextSync" ||
node.type_string() == "MultiDeviceIteratorGetNextFromShard")
return type_from_array_attr("output_shapes", "output_types");
if (node.type_string() == "InfeedDequeueTuple")
return type_from_array_attr("shapes", "dtypes");
if (node.type_string() == "InfeedDequeue") {
assert(idx == 0);
const auto& output_shape = node.attrs().Find("shape")->shape();
const auto& element_type = node.attrs().Find("dtype")->type();
return ConvertToMlirTensorType(output_shape, element_type, &builder);
}
// Returns a simple, more conservative unranked tensor type.
auto default_type = [&]() -> StatusOr<mlir::Type> {
mlir::Type element_type;