Flip behavior for SavedModel importer to not use shape inference on import
Shape inference during import will get deprecated and standalone pass used instead. PiperOrigin-RevId: 308362049 Change-Id: I01259b3c881e6d4b368faab0913f51ecb8de50d7
This commit is contained in:
parent
b72a719b11
commit
947bb83ce2
@ -48,7 +48,7 @@ class TestModule(tf.Module):
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
|
||||
# CHECK-SAME: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||
def some_function(self, x):
|
||||
|
@ -48,8 +48,8 @@ class TestModule(tf.Module):
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: ) -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["callee"]
|
||||
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL:[a-zA-Z_0-9]+]]
|
||||
#
|
||||
@ -57,7 +57,7 @@ class TestModule(tf.Module):
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<{{.*}}>> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}}
|
||||
# CHECK-SAME: ) -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["caller"]
|
||||
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL]]
|
||||
|
@ -28,7 +28,14 @@ class TestModule(tf.Module):
|
||||
|
||||
# Check that we get shapes annotated on function arguments.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<*xf32> {{.*}})
|
||||
# Besides checking the shape on the function input argument, this test also
|
||||
# checks that the shape on the input argument is propagated to the return
|
||||
# value.
|
||||
# We eventually want to move the shape inference to a pass separate from
|
||||
# the initial import, in which case that aspect of this test doesn't make much
|
||||
# sense and will be superceded by MLIR->MLIR shape inference tests.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<f32> {{.*}}) -> (tensor<f32> {{.*}})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||
def some_function(self, x):
|
||||
|
@ -0,0 +1,50 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: %p/shapes_for_variables | FileCheck %s
|
||||
|
||||
# pylint: disable=missing-docstring,line-too-long
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
|
||||
|
||||
|
||||
class TestModule(tf.Module):
|
||||
|
||||
# Check that we get shapes for variables used in the graph.
|
||||
# In this case, what we are testing is that the return type of the function is
|
||||
# correctly inferred, which requires understanding the shape of the variable
|
||||
# (in particular, the ReadVariableOp that reads it and returns a tensor).
|
||||
#
|
||||
# We eventually want to move the shape inference to a pass separate from
|
||||
# the initial import, in which case this test doesn't make much sense and
|
||||
# will be superceded by MLIR->MLIR shape inference tests.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}({{.*}}) -> (tensor<f32> {{.*}})
|
||||
# CHECK: tf_saved_model.exported_names = ["some_function"]
|
||||
def __init__(self):
|
||||
super(TestModule, self).__init__()
|
||||
self.my_variable = tf.Variable(42.)
|
||||
|
||||
@tf.function(input_signature=[])
|
||||
def some_function(self):
|
||||
return self.my_variable
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common.do_test(TestModule)
|
@ -35,7 +35,7 @@ class TestModule(tf.Module):
|
||||
# Check index paths for results.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0000_single_return(self):
|
||||
@ -46,8 +46,8 @@ class TestModule(tf.Module):
|
||||
# to returning a tuple/list.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0001_multiple_results_no_punctuation(self):
|
||||
@ -59,8 +59,8 @@ class TestModule(tf.Module):
|
||||
# of tf_saved_model users.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0002_multiple_results_parentheses(self):
|
||||
@ -72,8 +72,8 @@ class TestModule(tf.Module):
|
||||
# of tf_saved_model users.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0003_multiple_results_brackets(self):
|
||||
@ -82,8 +82,8 @@ class TestModule(tf.Module):
|
||||
# Check index paths for lists.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 0]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = [0, 1]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]},
|
||||
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0004_list_2_elements(self):
|
||||
@ -95,8 +95,8 @@ class TestModule(tf.Module):
|
||||
# path for linearization is shared, so no need to replicate that testing here.
|
||||
#
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}() -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]},
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["y"]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]},
|
||||
# CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"]
|
||||
@tf.function(input_signature=[])
|
||||
def f0005_dict_2_keys(self):
|
||||
@ -111,7 +111,7 @@ class TestModule(tf.Module):
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]}
|
||||
# CHECK-SAME: ) -> (
|
||||
# CHECK-SAME: tensor<*xf32> {tf_saved_model.index_path = ["x"]})
|
||||
# CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"]
|
||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||
def f0006_multiple_return_statements(self, x):
|
||||
|
@ -3065,7 +3065,6 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
||||
|
||||
GraphImportConfig specs;
|
||||
specs.prune_unused_nodes = true;
|
||||
specs.enable_shape_inference = false;
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
|
Loading…
Reference in New Issue
Block a user