Import shapes from attributes for tf.IteratorGetNext, tf.IteratorGetNextSync, tf.MultiDeviceIteratorGetNextFromShard, tf.InfeedDequeueTuple, and tf.InfeedDequeue when shape inference on import is not enabled, and disable shape inference on import in MlirFunctionOptimizationPass.
Shape inference on import is deprecated and these ops have trivial shapes to import from their attributes, which due to them being derived, are lost post import. PiperOrigin-RevId: 345698723 Change-Id: Iade6693fb692652bf170245a33d17c1ad2f4fadd
This commit is contained in:
parent
ee82f96f12
commit
2022bf35a4
@ -152,6 +152,12 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
import_config.graph_as_function = true;
|
||||
import_config.control_outputs = *control_ret_node_names;
|
||||
import_config.upgrade_legacy = true;
|
||||
// Disable shape inference during import as some TensorFlow op fails during
|
||||
// shape inference with dynamic shaped operands. This in turn causes the
|
||||
// import to fail. Shape inference during import is going to be removed and
|
||||
// the shape inference pass is run early in the pass pipeline, shape inference
|
||||
// during import is not necessary.
|
||||
import_config.enable_shape_inference = false;
|
||||
|
||||
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context);
|
||||
|
@ -0,0 +1,334 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -tf-graph-as-function %s -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "args_0"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "args_1"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "args_2"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "args_3"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "args_4"
|
||||
op: "_Arg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "IteratorGetNext"
|
||||
op: "IteratorGetNext"
|
||||
input: "args_0"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
}
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 16
|
||||
}
|
||||
}
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "IteratorGetNextSync"
|
||||
op: "IteratorGetNextSync"
|
||||
input: "args_0"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 24
|
||||
}
|
||||
}
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 4
|
||||
}
|
||||
dim {
|
||||
size: 32
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT16
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "MultiDeviceIteratorGetNextFromShard"
|
||||
op: "MultiDeviceIteratorGetNextFromShard"
|
||||
input: "args_2"
|
||||
input: "args_3"
|
||||
input: "args_4"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 5
|
||||
}
|
||||
dim {
|
||||
size: 40
|
||||
}
|
||||
}
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
shape {
|
||||
dim {
|
||||
size: 6
|
||||
}
|
||||
dim {
|
||||
size: 48
|
||||
}
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "InfeedDequeueTuple"
|
||||
op: "InfeedDequeueTuple"
|
||||
attr {
|
||||
key: "shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
shape {
|
||||
dim {
|
||||
size: 7
|
||||
}
|
||||
dim {
|
||||
size: 56
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtypes"
|
||||
value {
|
||||
list {
|
||||
type: DT_UINT16
|
||||
type: DT_UINT32
|
||||
type: DT_UINT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "InfeedDequeue_0"
|
||||
op: "InfeedDequeue"
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "InfeedDequeue_1"
|
||||
op: "InfeedDequeue"
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 8
|
||||
}
|
||||
dim {
|
||||
size: 64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_UINT8
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "InfeedDequeue_2"
|
||||
op: "InfeedDequeue"
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# CHECK-DAG: tf.IteratorGetNext{{.+}}-> (tensor<1x8xbf16>, tensor<2x?x16xf32>, tensor<*xf64>)
|
||||
# CHECK-DAG: tf.IteratorGetNextSync{{.+}}-> (tensor<*xi16>, tensor<3x24xi32>, tensor<?x4x32xi64>)
|
||||
# CHECK-DAG: tf.MultiDeviceIteratorGetNextFromShard{{.+}}-> (tensor<5x40xf16>, tensor<*xcomplex<f32>>, tensor<6x48x?xcomplex<f64>>)
|
||||
# CHECK-DAG: tf.InfeedDequeueTuple{{.+}}-> (tensor<?x?x?xui16>, tensor<*xui32>, tensor<7x56xui64>)
|
||||
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<?x8x?xi8>
|
||||
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<8x64xui8>
|
||||
# CHECK-DAG: tf.InfeedDequeue{{.+}}-> tensor<*xi1>
|
@ -980,6 +980,31 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx,
|
||||
}
|
||||
}
|
||||
|
||||
auto type_from_array_attr = [&node, &idx, &builder](
|
||||
absl::string_view output_shape_attr,
|
||||
absl::string_view element_type_attr) {
|
||||
auto* output_shapes = node.attrs().Find(output_shape_attr);
|
||||
auto* element_types = node.attrs().Find(element_type_attr);
|
||||
const auto& output_shape = output_shapes->list().shape(idx);
|
||||
const auto& element_type = element_types->list().type(idx);
|
||||
return ConvertToMlirTensorType(output_shape, element_type, &builder);
|
||||
};
|
||||
|
||||
if (node.type_string() == "IteratorGetNext" ||
|
||||
node.type_string() == "IteratorGetNextSync" ||
|
||||
node.type_string() == "MultiDeviceIteratorGetNextFromShard")
|
||||
return type_from_array_attr("output_shapes", "output_types");
|
||||
|
||||
if (node.type_string() == "InfeedDequeueTuple")
|
||||
return type_from_array_attr("shapes", "dtypes");
|
||||
|
||||
if (node.type_string() == "InfeedDequeue") {
|
||||
assert(idx == 0);
|
||||
const auto& output_shape = node.attrs().Find("shape")->shape();
|
||||
const auto& element_type = node.attrs().Find("dtype")->type();
|
||||
return ConvertToMlirTensorType(output_shape, element_type, &builder);
|
||||
}
|
||||
|
||||
// Returns a simple, more conservative unranked tensor type.
|
||||
auto default_type = [&]() -> StatusOr<mlir::Type> {
|
||||
mlir::Type element_type;
|
||||
|
Loading…
x
Reference in New Issue
Block a user