From 687a2d4bc5c70843e11c30c9d3e71c4e7bc14475 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Fri, 24 Apr 2020 16:22:28 -0700
Subject: [PATCH] Flip behavior for SavedModel importer to not use shape
 inference on import

Shape inference during import will get deprecated and standalone pass
used instead.

PiperOrigin-RevId: 308347824
Change-Id: I8849547a5cdd157a08017b5d23b455555b8a7c01
---
 .../tensorflow/tests/tf_saved_model/basic.py  |  2 +-
 .../tests/tf_saved_model/call_to_exported.py  |  6 +--
 .../tf_saved_model/shapes_for_arguments.py    |  9 +---
 .../tf_saved_model/shapes_for_variables.py    | 50 -------------------
 .../tests/tf_saved_model/structured_output.py | 24 ++++-----
 .../mlir/tensorflow/translate/import_model.cc |  1 +
 6 files changed, 18 insertions(+), 74 deletions(-)
 delete mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
index 78c18a17d4a..750a869682d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py
@@ -48,7 +48,7 @@ class TestModule(tf.Module):
   # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
   # CHECK-SAME: %arg1: tensor>> {tf_saved_model.bound_input = @[[VAR]]},
   # CHECK-SAME: %arg2: tensor>> {tf_saved_model.bound_input = @[[CONST]]}) -> (
-  # CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
   @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
   def some_function(self, x):
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py
index 658cc37a22f..d0d0c05544d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py
@@ -48,8 +48,8 @@ class TestModule(tf.Module):
   # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
   # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
   # CHECK-SAME: ) -> (
-  # CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
-  # CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [1]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
   # CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["callee"]
   # CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL:[a-zA-Z_0-9]+]]
   #
@@ -57,7 +57,7 @@ class TestModule(tf.Module):
   # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
   # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
   # CHECK-SAME: ) -> (
-  # CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
   # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
   # CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["caller"]
   # CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_arguments.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_arguments.py
index 2a72c9bbc6b..4b10701674b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_arguments.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_arguments.py
@@ -28,14 +28,7 @@ class TestModule(tf.Module):
 
   # Check that we get shapes annotated on function arguments.
   #
-  # Besides checking the shape on the function input argument, this test also
-  # checks that the shape on the input argument is propagated to the return
-  # value.
-  # We eventually want to move the shape inference to a pass separate from
-  # the initial import, in which case that aspect of this test doesn't make much
-  # sense and will be superceded by MLIR->MLIR shape inference tests.
-  #
-  # CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<f32> {{.*}})
+  # CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<*xf32> {{.*}})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
   @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
   def some_function(self, x):
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py
deleted file mode 100644
index 37290434f10..00000000000
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# RUN: %p/shapes_for_variables | FileCheck %s
-
-# pylint: disable=missing-docstring,line-too-long
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow.compat.v2 as tf
-from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
-
-
-class TestModule(tf.Module):
-
-  # Check that we get shapes for variables used in the graph.
-  # In this case, what we are testing is that the return type of the function is
-  # correctly inferred, which requires understanding the shape of the variable
-  # (in particular, the ReadVariableOp that reads it and returns a tensor).
-  #
-  # We eventually want to move the shape inference to a pass separate from
-  # the initial import, in which case this test doesn't make much sense and
-  # will be superceded by MLIR->MLIR shape inference tests.
-  #
-  # CHECK: func {{@[a-zA-Z_0-9]+}}({{.*}}) -> (tensor<f32> {{.*}})
-  # CHECK: tf_saved_model.exported_names = ["some_function"]
-  def __init__(self):
-    super(TestModule, self).__init__()
-    self.my_variable = tf.Variable(42.)
-
-  @tf.function(input_signature=[])
-  def some_function(self):
-    return self.my_variable
-
-
-if __name__ == '__main__':
-  common.do_test(TestModule)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py
index b476df0cc25..d53a8761eb9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py
@@ -35,7 +35,7 @@ class TestModule(tf.Module):
   # Check index paths for results.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"]
   @tf.function(input_signature=[])
   def f0000_single_return(self):
@@ -46,8 +46,8 @@ class TestModule(tf.Module):
   # to returning a tuple/list.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
-  # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"]
   @tf.function(input_signature=[])
   def f0001_multiple_results_no_punctuation(self):
@@ -59,8 +59,8 @@ class TestModule(tf.Module):
   # of tf_saved_model users.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
-  # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"]
   @tf.function(input_signature=[])
   def f0002_multiple_results_parentheses(self):
@@ -72,8 +72,8 @@ class TestModule(tf.Module):
   # of tf_saved_model users.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
-  # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"]
   @tf.function(input_signature=[])
   def f0003_multiple_results_brackets(self):
@@ -82,8 +82,8 @@ class TestModule(tf.Module):
   # Check index paths for lists.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]},
-  # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 0]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 1]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"]
   @tf.function(input_signature=[])
   def f0004_list_2_elements(self):
@@ -95,8 +95,8 @@ class TestModule(tf.Module):
   # path for linearization is shared, so no need to replicate that testing here.
   #
   # CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]},
-  # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]},
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["y"]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"]
   @tf.function(input_signature=[])
   def f0005_dict_2_keys(self):
@@ -111,7 +111,7 @@ class TestModule(tf.Module):
   # CHECK: func {{@[a-zA-Z_0-9]+}}(
   # CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]}
   # CHECK-SAME: ) -> (
-  # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]})
+  # CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]})
   # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"]
   @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
   def f0006_multiple_return_statements(self, x):
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 4f40fe19e8b..9e5cd503846 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3065,6 +3065,7 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
 
   GraphImportConfig specs;
   specs.prune_unused_nodes = true;
+  specs.enable_shape_inference = false;
   mlir::OwningModuleRef module =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
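Note (not part of the patch): with enable_shape_inference flipped off in the SavedModel object-graph importer, return types stay unranked (tensor<*xf32>) after import, and callers that want refined shapes run the standalone shape inference pass instead. A minimal sketch of that follow-up step, assuming the standalone pass is the existing -tf-shape-inference pass exposed as mlir::TF::CreateTFShapeInferencePass() in tensorflow/compiler/mlir/tensorflow/transforms/passes.h; the helper name RefineShapesAfterImport is purely illustrative, not an API added by this change:

// Sketch only: refine shapes on an imported SavedModel module by running the
// standalone TF shape inference pass, since the importer no longer does it.
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Hypothetical helper, shown for illustration.
mlir::LogicalResult RefineShapesAfterImport(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The same pass that tf-opt registers as -tf-shape-inference.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  return pm.run(module);
}

On the command line the equivalent would be piping the imported module through tf-opt -tf-shape-inference.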