From 26aa3d515d98ba59d52c80f523beb226913b1d07 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 25 Mar 2019 03:26:05 -0700 Subject: [PATCH] [XLA:AOT] Add a readonly flag to allow resource variables that are never assigned It's a bit unfortunate that we have to leave this decision to the user, but I don't see a better way of handling variables that are never assigned to. XLA will just prune away the output for them so we have to account for them early. The API follows the in-place API, but we make the setter const and don't generate getters. PiperOrigin-RevId: 240107733 --- tensorflow/compiler/aot/codegen.cc | 15 +++++++-- tensorflow/compiler/aot/codegen_test.cc | 16 +++++++--- tensorflow/compiler/aot/codegen_test_h.golden | 12 ++++--- tensorflow/compiler/aot/codegen_test_o.golden | Bin 800 -> 816 bytes .../compiler/aot/tests/make_test_graphs.py | 3 +- ...tfvariable_sequential_updates.config.pbtxt | 6 ++++ .../compiler/aot/tests/tfcompile_test.cc | 10 +++--- tensorflow/compiler/tf2xla/tf2xla.cc | 30 +++++++++++++++--- tensorflow/compiler/tf2xla/tf2xla.proto | 19 +++++++---- 9 files changed, 82 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2355fad8802..fe919d34919 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -213,7 +213,11 @@ Status GenResultMethods(const tf2xla::Config& config, return errors::Internal("codegen requires the XLA result to be a tuple"); } size_t num_results = ps.result().tuple_shapes_size(); - if (config.fetch_size() + config.variable_size() != num_results) { + int readonly_variables = absl::c_count_if( + config.variable(), + [](const tf2xla::Variable& var) { return var.readonly(); }); + if (config.fetch_size() + config.variable_size() - readonly_variables != + num_results) { return errors::InvalidArgument("mismatch between fetch_size(", config.fetch_size(), ")+variable_size(", config.variable_size(), ") and tuple_size(", 
@@ -256,15 +260,17 @@ Status GenVariableMethods(const tf2xla::Config& config, TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_var_{{NAME}}_data({{TYPE}}* data) { + void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) { set_arg_data({{I}}, data); } )"; const tf2xla::Variable& var = config.variable(i - config.feed_size()); + rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : ""); *methods += RewriteWithName( var.name().empty() ? var.node_name() : var.name(), code, rewrites); } size_t num_results = ps.result().tuple_shapes_size(); + int variable_num = -1; for (int i = config.fetch_size(); i < num_results; ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR(AddRewritesForShape( @@ -285,7 +291,10 @@ Status GenVariableMethods(const tf2xla::Config& config, result_data({{I}}))){{INDICES}}; } )"; - const tf2xla::Variable& var = config.variable(i - config.fetch_size()); + do { + ++variable_num; + } while (config.variable(variable_num).readonly()); + const tf2xla::Variable& var = config.variable(variable_num); *methods += RewriteWithName( var.name().empty() ? 
var.node_name() : var.name(), code, rewrites); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 5580e55b691..46f1cd57dde 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -175,14 +175,19 @@ TEST(CodegenTest, Golden) { fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); tf2xla::Variable* variable = config.add_variable(); - variable->set_node_name("myvar"); + variable->set_node_name("myvar_readonly"); variable->mutable_shape()->add_dim()->set_size(1); variable->set_type(DT_FLOAT); + variable->set_readonly(true); tf2xla::Variable* variable2 = config.add_variable(); - variable2->set_node_name("my/var"); - variable2->set_name("myvar2"); - variable2->mutable_shape()->add_dim()->set_size(5); - variable2->set_type(DT_INT32); + variable2->set_node_name("myvar"); + variable2->mutable_shape()->add_dim()->set_size(1); + variable2->set_type(DT_FLOAT); + tf2xla::Variable* variable3 = config.add_variable(); + variable3->set_node_name("my/var"); + variable3->set_name("myvar2"); + variable3->mutable_shape()->add_dim()->set_size(5); + variable3->set_type(DT_INT32); CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, @@ -198,6 +203,7 @@ TEST(CodegenTest, Golden) { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), xla::ShapeUtil::MakeShape(xla::S32, {5}), }, xla::ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 8591df53877..1ffa39b0e3b 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -52,7 +52,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. 
// // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: // arg bytes total: 104 @@ -228,14 +228,18 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_var_myvar_data(float* data) { + void set_var_myvar_readonly_data(const float* data) { set_arg_data(2, data); } - void set_var_myvar2_data(tensorflow::int32* data) { + void set_var_myvar_data(float* data) { set_arg_data(3, data); } + void set_var_myvar2_data(tensorflow::int32* data) { + set_arg_data(4, data); + } + float* var_myvar_data() { return static_cast(result_data(1)); } @@ -309,7 +313,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 149); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index 2884597abcf29583e6192296b0e4ce6825d7c01a..38c75d1fb60dba018c0f7d412b8bc8eb96a0e3ee 100644 GIT binary patch delta 51 zcmZ3$wt;Pe2IGc_n!yvZOC~n!yu$lPA8{WNeu%$@rd8U@|9DJCHq*$(-@X graph, arg.initialized = true; xla_args.push_back(std::move(arg)); - // We want to alias the input and output of the variable, so the updates are - // carried out in-place. 
- xla_aliases.push_back({/*output_index=*/{output_num}, - /*param_number=*/input_num, /*param_index=*/{}}); + if (!variable.readonly()) { + // We want to alias the input and output of the variable, so the updates + // are carried out in-place. + xla_aliases.push_back({/*output_index=*/{output_num}, + /*param_number=*/input_num, /*param_index=*/{}}); + ++output_num; + } ++input_num; - ++output_num; } // Compile the graph into an XLA computation. @@ -324,6 +326,24 @@ Status ConvertGraphToXla(std::unique_ptr graph, " constant results. The configuration of " "the output args (i.e. fetch ids) is probably wrong."); } + { + // Verify that the readonly bits on variables are set correctly by the user. + std::vector updated_inputs(xla_args.size()); + for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) { + updated_inputs[update.input_index] = true; + } + int64 input_index = xla_args.size() - config.variable_size(); + for (const tf2xla::Variable& variable : config.variable()) { + if (variable.readonly() == updated_inputs[input_index]) { + return errors::InvalidArgument( + "Variable \"", variable.node_name(), "\" is marked as ", + variable.readonly() ? "" : "not ", "readonly, but is ", + updated_inputs[input_index] ? 
"" : "not ", + "modified by the computation."); + } + ++input_index; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index 5627af7452b..f47608155ec 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -1,14 +1,15 @@ syntax = "proto3"; package tensorflow.tf2xla; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + option cc_enable_arenas = true; option java_outer_classname = "Tf2XlaProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.tf2xla"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - // TensorId identifies a tensor in a TensorFlow graph, by specifying the output // index of a particular node in the graph. If the output of the named node // feeds into other node(s), this corresponds to one or more edges. Otherwise @@ -16,7 +17,7 @@ import "tensorflow/core/framework/types.proto"; message TensorId { string node_name = 1; int64 output_index = 2; -}; +} // Feed represents a single feed tensor in the graph, which corresponds to an // input argument for the generated computation. @@ -30,14 +31,14 @@ message Feed { // not linked into the binary, then the type cannot be inferred from the node; // in this case, the type should be set here. DataType type = 4; -}; +} // Fetch represents a single fetch tensor in the graph, which corresponds to an // output argument for the generated computation. message Fetch { TensorId id = 1; string name = 2; // Optional name for generated code. -}; +} // Variable represents a resource variable with the given name, shape and type. message Variable { @@ -46,6 +47,10 @@ message Variable { 2; // Optional name for generated code. If empty, node_name will be used. TensorShapeProto shape = 3; DataType type = 4; + + // Flag for variables that are never assigned. 
Assignments to a read-only + // variable or unassigned variables that are not read-only are invalid. + bool readonly = 5; } // Config represents configuration information for tf2xla conversion. @@ -58,4 +63,4 @@ message Config { repeated Fetch fetch = 2; // Each variable is a named input and output of the generated computation. repeated Variable variable = 3; -}; +}