[TF saved_model_cli AOT] Unit tests for AOT compilation with fed variables.

Fixes a bug where we were freezing "readonly" variables.  Now we do not freeze any
variables that are to be fed.

PiperOrigin-RevId: 295249887
Change-Id: Ib18f07a15c07f7a893604a2cbf3d447fb7cf8f75
This commit is contained in:
Eugene Brevdo 2020-02-14 16:15:42 -08:00 committed by TensorFlower Gardener
parent a428056e32
commit 1bceffd1ae
6 changed files with 101 additions and 37 deletions

View File

@ -364,7 +364,7 @@ py_test(
saved_model_compile_aot(
name = "aot_compiled_x_plus_y",
cpp_class = "CompiledModel",
cpp_class = "XPlusY",
directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
filegroups = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
@ -373,15 +373,40 @@ saved_model_compile_aot(
tags = ["no_rocm"],
)
saved_model_compile_aot(
name = "aot_compiled_vars_and_arithmetic_frozen",
cpp_class = "VarsAndArithmeticFrozen",
directory = "//tensorflow/cc/saved_model:testdata/VarsAndArithmeticObjectGraph",
filegroups = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
force_without_xla_support_flag = False,
tags = ["no_rocm"],
)
saved_model_compile_aot(
name = "aot_compiled_vars_and_arithmetic",
cpp_class = "VarsAndArithmetic",
directory = "//tensorflow/cc/saved_model:testdata/VarsAndArithmeticObjectGraph",
filegroups = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
force_without_xla_support_flag = False,
tags = ["no_rocm"],
variables_to_feed = "variable_x",
)
tf_cc_test(
name = "binary_using_aot_compiled_x_plus_y_test",
name = "aot_compiled_test",
srcs = if_xla_available([
"binary_using_aot_compiled_x_plus_y_test.cc",
"aot_compiled_test.cc",
]),
tags = ["no_rocm"],
deps = [
"//tensorflow/core:test_main",
] + if_xla_available([
":aot_compiled_vars_and_arithmetic",
":aot_compiled_vars_and_arithmetic_frozen",
":aot_compiled_x_plus_y",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",

View File

@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h"
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h"
#include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
namespace tensorflow {
namespace {
TEST(AOTCompiledSavedModelTest, XPlusY) {
XPlusY model;
// Calculation is: output_0 = x + y.
*model.arg_feed_x_data() = 3.0f;
*model.arg_feed_y_data() = 4.0f;
CHECK(model.Run());
ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
}
TEST(AOTCompiledSavedModelTest, VarsAndArithmetic) {
VarsAndArithmeticFrozen frozen_model;
// Calculation is:
// output_0 = [(a + variable_x) * (b + variable_y) / child_variable] + 5.0
// where {variable_x, variable_y, child_variable} = {1.0, 2.0, 3.0} when
// initialized (frozen).
*frozen_model.arg_feed_a_data() = 1.0f;
*frozen_model.arg_feed_b_data() = 2.0f;
CHECK(frozen_model.Run());
ASSERT_NEAR(frozen_model.result_fetch_output_0(),
(1.0f + 1.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);
VarsAndArithmetic nonfrozen_model;
*nonfrozen_model.arg_feed_a_data() = 1.0f;
*nonfrozen_model.arg_feed_b_data() = 2.0f;
// variable_x is no longer frozen. set it to 4.0;
float new_variable_x = 4.0f;
nonfrozen_model.set_var_param_variable_x_data(&new_variable_x);
CHECK(nonfrozen_model.Run());
ASSERT_NEAR(nonfrozen_model.result_fetch_output_0(),
(1.0f + 4.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);
}
} // namespace
} // namespace tensorflow

View File

@ -1,30 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
namespace tensorflow {
namespace {
TEST(AOTCompiledSavedModelTest, Run) {
CompiledModel model;
*model.arg_feed_x_data() = 3.0f;
*model.arg_feed_y_data() = 4.0f;
CHECK(model.Run());
ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
}
} // namespace
} // namespace tensorflow

View File

@ -311,8 +311,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
for n in signature_def.outputs.values()
],
variable_names_blacklist=[
name for (name, node_modified) in all_variables.items()
if node_modified[1]
n.name for n, _ in variable_nodes_to_feed
],
))

View File

@ -1097,8 +1097,12 @@ def add_aot_compile_cpu_subparser(subparsers):
'Options are: empty (default; all variables are frozen, none may '
'be fed), \'all\' (all variables may be fed), or a '
'comma-delimited list of names of variables that may be fed. In '
'the last case, the non-fed variables will be frozen in the graph.')
)
'the last case, the non-fed variables will be frozen in the graph.'
'**NOTE** Any variables passed to `variables_to_feed` *must be set '
'by the user*. These variables will NOT be frozen and their '
'values will be uninitialized in the compiled object '
'(this applies to all input arguments from the signature as '
'well).'))
parser_compile.set_defaults(func=aot_compile_cpu)

View File

@ -27,6 +27,11 @@ def saved_model_compile_aot(
SavedModel and generates a cc_library with an AOT compiled model.
For extra details, see the help for saved_model_cli's aot_compile_cpu help.
**NOTE** Any variables passed to `variables_to_feed` *must be set by the
user*. These variables will NOT be frozen and their values will be
uninitialized in the compiled object (this applies to all input
arguments from the signature as well).
Example usage:
```
@ -77,6 +82,11 @@ def saved_model_compile_aot(
variables_to_feed: (optional) The names of the variables to feed, a comma
separated string, or 'all'. If empty, all variables will be frozen and none
may be fed at runtime.
**NOTE** Any variables passed to `variables_to_feed` *must be set by
the user*. These variables will NOT be frozen and their values will be
uninitialized in the compiled object (this applies to all input
arguments from the signature as well).
target_triple: The LLVM target triple to use (defaults to current build's
target architecture's triple).
force_without_xla_support_flag: Whether to compile even when