Flip behavior for SavedModel importer to not use shape inference on import

Shape inference during import will be deprecated, and a standalone pass will be used instead.

PiperOrigin-RevId: 308362049
Change-Id: I01259b3c881e6d4b368faab0913f51ecb8de50d7
This commit is contained in:
A. Unique TensorFlower 2020-04-24 18:11:24 -07:00 committed by TensorFlower Gardener
parent b72a719b11
commit 947bb83ce2
6 changed files with 74 additions and 18 deletions

View File

@@ -48,7 +48,7 @@ class TestModule(tf.Module):
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
# CHECK-SAME: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[CONST]]}) -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def some_function(self, x):

View File

@@ -48,8 +48,8 @@ class TestModule(tf.Module):
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: ) -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["callee"]
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL:[a-zA-Z_0-9]+]]
#
@@ -57,7 +57,7 @@ class TestModule(tf.Module):
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
# CHECK-SAME: ) -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["caller"]
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL]]

View File

@@ -28,7 +28,14 @@ class TestModule(tf.Module):
# Check that we get shapes annotated on function arguments.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<*xf32> {{.*}})
# Besides checking the shape on the function input argument, this test also
# checks that the shape on the input argument is propagated to the return
# value.
# We eventually want to move the shape inference to a pass separate from
# the initial import, in which case that aspect of this test doesn't make much
# sense and will be superseded by MLIR->MLIR shape inference tests.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<f32> {{.*}})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def some_function(self, x):

View File

@@ -0,0 +1,50 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/shapes_for_variables | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
class TestModule(tf.Module):
  """FileCheck test: variable shapes must be inferred during SavedModel import.

  The module exports a single function that returns a variable; the CHECK
  lines verify that the function's MLIR return type is the concrete
  tensor<f32> rather than the unranked tensor<*xf32>.
  """

  # Check that we get shapes for variables used in the graph.
  # In this case, what we are testing is that the return type of the function is
  # correctly inferred, which requires understanding the shape of the variable
  # (in particular, the ReadVariableOp that reads it and returns a tensor).
  #
  # We eventually want to move the shape inference to a pass separate from
  # the initial import, in which case this test doesn't make much sense and
  # will be superseded by MLIR->MLIR shape inference tests.
  #
  # CHECK: func {{@[a-zA-Z_0-9]+}}({{.*}}) -> (tensor<f32> {{.*}})
  # CHECK: tf_saved_model.exported_names = ["some_function"]

  def __init__(self):
    super(TestModule, self).__init__()
    # Scalar float variable: its known shape (<f32>) is what the CHECK lines
    # expect to see propagated into the exported function's result type.
    self.my_variable = tf.Variable(42.)

  @tf.function(input_signature=[])
  def some_function(self):
    return self.my_variable


if __name__ == '__main__':
  # Saves the module, re-imports it to MLIR, and pipes the output to FileCheck
  # (see the RUN line at the top of the file).
  common.do_test(TestModule)

View File

@@ -35,7 +35,7 @@ class TestModule(tf.Module):
# Check index paths for results.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"]
@tf.function(input_signature=[])
def f0000_single_return(self):
@@ -46,8 +46,8 @@ class TestModule(tf.Module):
# to returning a tuple/list.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"]
@tf.function(input_signature=[])
def f0001_multiple_results_no_punctuation(self):
@@ -59,8 +59,8 @@ class TestModule(tf.Module):
# of tf_saved_model users.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"]
@tf.function(input_signature=[])
def f0002_multiple_results_parentheses(self):
@@ -72,8 +72,8 @@ class TestModule(tf.Module):
# of tf_saved_model users.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"]
@tf.function(input_signature=[])
def f0003_multiple_results_brackets(self):
@@ -82,8 +82,8 @@ class TestModule(tf.Module):
# Check index paths for lists.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 0]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 1]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]},
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"]
@tf.function(input_signature=[])
def f0004_list_2_elements(self):
@@ -95,8 +95,8 @@ class TestModule(tf.Module):
# path for linearization is shared, so no need to replicate that testing here.
#
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]},
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["y"]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]},
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"]
@tf.function(input_signature=[])
def f0005_dict_2_keys(self):
@@ -111,7 +111,7 @@ class TestModule(tf.Module):
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]}
# CHECK-SAME: ) -> (
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]})
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]})
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"]
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def f0006_multiple_return_statements(self, x):

View File

@@ -3065,7 +3065,6 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
GraphImportConfig specs;
specs.prune_unused_nodes = true;
specs.enable_shape_inference = false;
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;