[TF saved_model_cli AOT] Unit tests for AOT compilation with fed variables.
Fixes a bug where we were freezing "readonly" variables. Now we do not freeze any variables that are to be fed. PiperOrigin-RevId: 295249887 Change-Id: Ib18f07a15c07f7a893604a2cbf3d447fb7cf8f75
This commit is contained in:
parent
a428056e32
commit
1bceffd1ae
@ -364,7 +364,7 @@ py_test(
|
||||
|
||||
saved_model_compile_aot(
|
||||
name = "aot_compiled_x_plus_y",
|
||||
cpp_class = "CompiledModel",
|
||||
cpp_class = "XPlusY",
|
||||
directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
|
||||
filegroups = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
@ -373,15 +373,40 @@ saved_model_compile_aot(
|
||||
tags = ["no_rocm"],
|
||||
)
|
||||
|
||||
saved_model_compile_aot(
|
||||
name = "aot_compiled_vars_and_arithmetic_frozen",
|
||||
cpp_class = "VarsAndArithmeticFrozen",
|
||||
directory = "//tensorflow/cc/saved_model:testdata/VarsAndArithmeticObjectGraph",
|
||||
filegroups = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
force_without_xla_support_flag = False,
|
||||
tags = ["no_rocm"],
|
||||
)
|
||||
|
||||
saved_model_compile_aot(
|
||||
name = "aot_compiled_vars_and_arithmetic",
|
||||
cpp_class = "VarsAndArithmetic",
|
||||
directory = "//tensorflow/cc/saved_model:testdata/VarsAndArithmeticObjectGraph",
|
||||
filegroups = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
force_without_xla_support_flag = False,
|
||||
tags = ["no_rocm"],
|
||||
variables_to_feed = "variable_x",
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "binary_using_aot_compiled_x_plus_y_test",
|
||||
name = "aot_compiled_test",
|
||||
srcs = if_xla_available([
|
||||
"binary_using_aot_compiled_x_plus_y_test.cc",
|
||||
"aot_compiled_test.cc",
|
||||
]),
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/core:test_main",
|
||||
] + if_xla_available([
|
||||
":aot_compiled_vars_and_arithmetic",
|
||||
":aot_compiled_vars_and_arithmetic_frozen",
|
||||
":aot_compiled_x_plus_y",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/platform:logging",
|
||||
|
56
tensorflow/python/tools/aot_compiled_test.cc
Normal file
56
tensorflow/python/tools/aot_compiled_test.cc
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h"
|
||||
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h"
|
||||
#include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
TEST(AOTCompiledSavedModelTest, XPlusY) {
|
||||
XPlusY model;
|
||||
// Calculation is: output_0 = x + y.
|
||||
*model.arg_feed_x_data() = 3.0f;
|
||||
*model.arg_feed_y_data() = 4.0f;
|
||||
CHECK(model.Run());
|
||||
ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
|
||||
}
|
||||
|
||||
TEST(AOTCompiledSavedModelTest, VarsAndArithmetic) {
|
||||
VarsAndArithmeticFrozen frozen_model;
|
||||
// Calculation is:
|
||||
// output_0 = [(a + variable_x) * (b + variable_y) / child_variable] + 5.0
|
||||
// where {variable_x, variable_y, child_variable} = {1.0, 2.0, 3.0} when
|
||||
// initialized (frozen).
|
||||
*frozen_model.arg_feed_a_data() = 1.0f;
|
||||
*frozen_model.arg_feed_b_data() = 2.0f;
|
||||
CHECK(frozen_model.Run());
|
||||
ASSERT_NEAR(frozen_model.result_fetch_output_0(),
|
||||
(1.0f + 1.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);
|
||||
|
||||
VarsAndArithmetic nonfrozen_model;
|
||||
*nonfrozen_model.arg_feed_a_data() = 1.0f;
|
||||
*nonfrozen_model.arg_feed_b_data() = 2.0f;
|
||||
// variable_x is no longer frozen. set it to 4.0;
|
||||
float new_variable_x = 4.0f;
|
||||
nonfrozen_model.set_var_param_variable_x_data(&new_variable_x);
|
||||
CHECK(nonfrozen_model.Run());
|
||||
ASSERT_NEAR(nonfrozen_model.result_fetch_output_0(),
|
||||
(1.0f + 4.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -1,30 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
TEST(AOTCompiledSavedModelTest, Run) {
|
||||
CompiledModel model;
|
||||
*model.arg_feed_x_data() = 3.0f;
|
||||
*model.arg_feed_y_data() = 4.0f;
|
||||
CHECK(model.Run());
|
||||
ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -311,8 +311,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
|
||||
for n in signature_def.outputs.values()
|
||||
],
|
||||
variable_names_blacklist=[
|
||||
name for (name, node_modified) in all_variables.items()
|
||||
if node_modified[1]
|
||||
n.name for n, _ in variable_nodes_to_feed
|
||||
],
|
||||
))
|
||||
|
||||
|
@ -1097,8 +1097,12 @@ def add_aot_compile_cpu_subparser(subparsers):
|
||||
'Options are: empty (default; all variables are frozen, none may '
|
||||
'be fed), \'all\' (all variables may be fed), or a '
|
||||
'comma-delimited list of names of variables that may be fed. In '
|
||||
'the last case, the non-fed variables will be frozen in the graph.')
|
||||
)
|
||||
'the last case, the non-fed variables will be frozen in the graph.'
|
||||
'**NOTE** Any variables passed to `variables_to_feed` *must be set '
|
||||
'by the user*. These variables will NOT be frozen and their '
|
||||
'values will be uninitialized in the compiled object '
|
||||
'(this applies to all input arguments from the signature as '
|
||||
'well).'))
|
||||
|
||||
parser_compile.set_defaults(func=aot_compile_cpu)
|
||||
|
||||
|
@ -27,6 +27,11 @@ def saved_model_compile_aot(
|
||||
SavedModel and generates a cc_library with an AOT compiled model.
|
||||
For extra details, see the help for saved_model_cli's aot_compile_cpu help.
|
||||
|
||||
**NOTE** Any variables passed to `variables_to_feed` *must be set by the
|
||||
user*. These variables will NOT be frozen and their values will be
|
||||
uninitialized in the compiled object (this applies to all input
|
||||
arguments from the signature as well).
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
@ -77,6 +82,11 @@ def saved_model_compile_aot(
|
||||
variables_to_feed: (optional) The names of the variables to feed, a comma
|
||||
separated string, or 'all'. If empty, all variables will be frozen and none
|
||||
may be fed at runtime.
|
||||
|
||||
**NOTE** Any variables passed to `variables_to_feed` *must be set by
|
||||
the user*. These variables will NOT be frozen and their values will be
|
||||
uninitialized in the compiled object (this applies to all input
|
||||
arguments from the signature as well).
|
||||
target_triple: The LLVM target triple to use (defaults to current build's
|
||||
target architecture's triple).
|
||||
force_without_xla_support_flag: Whether to compile even when
|
||||
|
Loading…
Reference in New Issue
Block a user