From 38ca061ed3c2c6e9e32dec0654c77ca890a1d734 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Fri, 20 Mar 2020 12:58:57 -0700 Subject: [PATCH] Fix a check in tfcompile codegen. Previously, tfcompile expects the compiler to translate each resource variable into a function argument. The MLIR bridge doesn't generated a function argument for an unused resource variable. Modify an existing test to test the situation. PiperOrigin-RevId: 302083482 Change-Id: I08301e594422f655b8d4ba4bb66d69103764cc7f --- tensorflow/compiler/aot/codegen.cc | 4 +++- tensorflow/compiler/aot/tests/make_test_graphs.py | 1 + .../aot/tests/test_graph_tfvariable_readonly.config.pbtxt | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 4a4fec5a386..c9a36b88795 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -170,7 +170,9 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (config.feed_size() + config.variable_size() != num_args) { + // feed_size() + variable_size() is the maximum number of args as an + // implementation may not create an argument for an unused variable. + if (config.feed_size() + config.variable_size() < num_args) { return errors::InvalidArgument( "mismatch between feed_size(", config.feed_size(), ")+variable_size(", config.variable_size(), ") and num_args(", num_args, ")"); diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 629239d6e4a..532d64c5a3e 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -157,6 +157,7 @@ def tftop_k(_): def tfvariable_readonly(_): x = variables.Variable(1000.0, name='x') + unused_y = variables.Variable(1000.0, name='y') old_x = x.value() with ops.control_dependencies([old_x]): new_value = math_ops.add(old_x, 42.0) diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt index b615b8f1522..dd2d0399451 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.config.pbtxt @@ -10,3 +10,11 @@ variable { type: DT_FLOAT readonly: true } + +variable { + node_name: "y" + shape { + } + type: DT_FLOAT + readonly: true +}