From 8e8c97cf6e2f66b4a9daeddb287fc94f2183f0ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 10 Mar 2020 21:07:50 -0700 Subject: [PATCH] [xla] add better support for variables(num_variables, index lookup by name) in xla aot/jit. PiperOrigin-RevId: 300243405 Change-Id: Iab455be5b0d3ec594b8482de2e61f5049bc4cb14 --- tensorflow/compiler/aot/codegen.cc | 25 +++++++++++-- tensorflow/compiler/aot/codegen_test.cc | 19 +++++----- tensorflow/compiler/aot/codegen_test_h.golden | 35 ++++++++++++++----- .../tf2xla/xla_compiled_cpu_function.cc | 13 ++++++- .../tf2xla/xla_compiled_cpu_function.h | 35 +++++++++++++++++-- .../tf2xla/xla_jit_compiled_cpu_function.cc | 20 +++++++++-- .../tf2xla/xla_jit_compiled_cpu_function.h | 2 ++ .../xla_jit_compiled_cpu_function_test.cc | 11 ++++++ 8 files changed, 134 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 53150e991cc..4a4fec5a386 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config, } // Generates code implementing {Arg,Result}Names(), where T is one of -// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string -// literal in the array, with nullptr terminating the array. +// tf2xla::{Feed,Fetch,Variable}. Each entry name results in a C-style +// string literal in the array, with nullptr terminating the array. 
template string GenNameToIndexCode(const T& entries, bool generate) { // No need for a static array if we're not supposed to generate the data. @@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // Generate metadata. const string arg_names_code = GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + + auto variable_copy = config.variable(); + for (auto& var : variable_copy) { + if (var.name().empty()) { + var.set_name(var.node_name()); + } + } + const string variable_names_code = + GenNameToIndexCode(variable_copy, opts.gen_name_to_index); + const string result_names_code = GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); const string include_xla_data_proto = @@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; + // Number of variables for the compiled computation. + static constexpr size_t kNumVariables = {{VARIABLE_NUM}}; + // Byte size of each argument buffer. There are kNumArgs entries. 
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); @@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { set_static_data_num_buffers(data, kNumBuffers); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); + set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); set_static_data_arg_names(data, StaticArgNames()); + set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( @@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} + // Array of names of each positional variable, terminated by nullptr. + static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}} + // Array of names of each positional result, terminated by nullptr. 
static const char** StaticResultNames() {{RESULT_NAMES_CODE}} @@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, + {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, @@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, + {"{{VARIABLE_NAMES_CODE}}", variable_names_code}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 6206f68faf9..babbd7fb2f5 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -156,17 +156,14 @@ static void CompareWithGoldenFile( // bazel test --test_strategy=local \ // third_party/tensorflow/compiler/aot:codegen_test const bool update_golden = false; - string golden_file_name; + string golden_file_name = + GetDataDependencyFilepath(tensorflow_relative_golden_file_name); if (update_golden) { - golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(), - tensorflow_relative_golden_file_name); TF_EXPECT_OK( WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); } - golden_file_name = - GetDataDependencyFilepath(tensorflow_relative_golden_file_name); string golden_file_contents; TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, &golden_file_contents)); @@ -220,10 +217,16 @@ TEST(CodegenTest, Golden) { {}, 
{BufferInfo::MakeTempBuffer(1), BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), - BufferInfo::MakeTempBuffer(2), + BufferInfo::MakeTempBuffer(1), BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), - BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, - 5, {})); + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4), + BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)}, + 11, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 1669e728d1a..af58ca233f0 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -55,14 +55,17 @@ namespace bar { // ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: -// arg bytes total: 104 -// arg bytes aligned: 192 +// arg bytes total: 392 +// arg bytes aligned: 576 // temp bytes total: 126 -// temp bytes aligned: 320 +// temp bytes aligned: 512 class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. - static constexpr size_t kNumArgs = 2; + static constexpr size_t kNumArgs = 5; + + // Number of variables for the compiled computation. + static constexpr size_t kNumVariables = 3; // Byte size of each argument buffer. There are kNumArgs entries. 
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { @@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { set_static_data_num_buffers(data, kNumBuffers); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); + set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); set_static_data_arg_names(data, StaticArgNames()); + set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( @@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { private: // Number of buffers for the compiled computation. - static constexpr size_t kNumBuffers = 6; + static constexpr size_t kNumBuffers = 12; static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { static const ::xla::cpu_function_runtime::BufferInfo kBufferInfos[kNumBuffers] = { ::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), -::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), -::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}), +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), ::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) }; return kBufferInfos; @@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const 
::tensorflow::int32* ArgIndexToBufferIndex() { static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { -1, 3 +1, 3, 5, 7, 9 }; return kArgIndexToBufferIndex; } // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 5; + static constexpr size_t kResultIndex = 11; // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { @@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return kNames; } + // Array of names of each positional variable, terminated by nullptr. + static const char** StaticVariableNames() { + static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr}; + return kNames; + } + // Array of names of each positional result, terminated by nullptr. static const char** StaticResultNames() { static const char* kNames[] = {"myfetch", nullptr}; diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 5420cf3e04f..3870a673e4e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,7 +28,9 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, buffer_infos_(static_data.buffer_infos_), arg_index_table_(static_data.arg_index_table_), num_args_(static_data.num_args_), + num_variables_(static_data.num_variables_), arg_names_(static_data.arg_names_), + variable_names_(static_data.variable_names_), result_names_(static_data.result_names_), program_shape_(static_data.program_shape_), hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { @@ -63,6 +65,8 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { namespace { +constexpr int kNotFound = -1; + // Linear search through `names` looking for a match with `name`. Returns -1 if // the name isn't found, or is empty. 
// @@ -72,7 +76,6 @@ int LookupNameIndex(const string& name, const char** names) { // for AOT try the setting the tfcompile --gen_name_to_index flag. assert(names != nullptr); - constexpr int kNotFound = -1; if (name.empty()) { return kNotFound; } @@ -90,6 +93,14 @@ int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { return LookupNameIndex(name, arg_names_); } +int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const { + int index = LookupNameIndex(name, variable_names_); + if (index == kNotFound) { + return kNotFound; + } + return num_args_ - num_variables_ + index; +} + int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { return LookupNameIndex(name, result_names_); } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 5e452b50e71..04d9086ce4c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -76,12 +76,16 @@ class XlaCompiledCpuFunction { // There are num_args entry parameters. int64 num_args_ = 0; + // There are num_variables variables. + int64 num_variables_ = 0; + // The 0-based index of the result tuple, in the temp buffers. size_t result_index_ = 0; // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; const char** result_names_ = nullptr; // [Optional] Arg and result shapes. @@ -150,6 +154,8 @@ class XlaCompiledCpuFunction { int num_args() const { return num_args_; } + int num_variables() const { return num_variables_; } + // Returns the size of entry parameter `idx`. // // There is a static version of this method on tfcompile generated subclasses @@ -212,10 +218,11 @@ class XlaCompiledCpuFunction { // ------------------------------ // Methods for extracting optional metadata. 
- // Returns true iff data is available for the Lookup{Arg,Result}Index methods. - // E.g. the data might not be compiled into the binary for AOT. + // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index + // methods. E.g. the data might not be compiled into the binary for AOT. bool HasNameIndices() const { - return arg_names_ != nullptr && result_names_ != nullptr; + return arg_names_ != nullptr && variable_names_ != nullptr && + result_names_ != nullptr; } // Returns the 0-based index for the argument with the given `name`. @@ -226,6 +233,14 @@ class XlaCompiledCpuFunction { // Recommended usage is to capture this in a variable for re-use. int LookupArgIndex(const string& name) const; + // Returns the 0-based index for the variable with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupVariableIndex(const string& name) const; + // Returns the 0-based index for the result with the given `name`. // Returns -1 if the name wasn't found, or data isn't available. 
// @@ -280,6 +295,11 @@ class XlaCompiledCpuFunction { static_data->num_args_ = num_args; } + static void set_static_data_num_variables(StaticData* static_data, + int64 num_variables) { + static_data->num_variables_ = num_variables; + } + static void set_static_data_result_index(StaticData* static_data, size_t result_index) { static_data->result_index_ = result_index; @@ -290,6 +310,11 @@ class XlaCompiledCpuFunction { static_data->arg_names_ = arg_names; } + static void set_static_data_variable_names(StaticData* static_data, + const char** variable_names) { + static_data->variable_names_ = variable_names; + } + static void set_static_data_result_names(StaticData* static_data, const char** result_names) { static_data->result_names_ = result_names; @@ -334,6 +359,9 @@ class XlaCompiledCpuFunction { // The number of incoming arguments. const int32 num_args_; + // The number of incoming variables. + const int32 num_variables_; + // Backing memory for buffer_table_ and args_, the latter depending on // AllocMode. void* alloc_buffer_table_ = nullptr; @@ -346,6 +374,7 @@ class XlaCompiledCpuFunction { // Optional metadata. const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 0392cc7d345..0deaa1ea8fb 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -49,9 +49,9 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold -// the actual strings in nonempty_names, and hold arrays of pointers in -// name_ptrs, terminated by a nullptr entry. 
+// Collect names from `entries`, where T is one of +// tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names, +// and hold arrays of pointers in name_ptrs, terminated by a nullptr entry. template void CollectNames(const T& entries, std::vector* nonempty_names, std::vector* name_ptrs) { @@ -154,14 +154,28 @@ XlaJitCompiledCpuFunction::Compile( &jit->static_data_, jit->arg_index_table_.data()); XlaCompiledCpuFunction::set_static_data_num_args( &jit->static_data_, jit->arg_index_table_.size()); + XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_, + config.variable_size()); XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, result_index); // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); + + auto variable_copy = config.variable(); + for (auto& var : variable_copy) { + if (var.name().empty()) { + var.set_name(var.node_name()); + } + } + CollectNames(variable_copy, &jit->nonempty_variable_names_, + &jit->variable_names_); + CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_, jit->arg_names_.data()); + XlaCompiledCpuFunction::set_static_data_variable_names( + &jit->static_data_, jit->variable_names_.data()); XlaCompiledCpuFunction::set_static_data_result_names( &jit->static_data_, jit->result_names_.data()); XlaCompiledCpuFunction::set_static_data_program_shape( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index 11fc4571189..107968b184d 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -77,8 +77,10 @@ class XlaJitCompiledCpuFunction { // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static // data to refer to. 
std::vector nonempty_arg_names_; + std::vector nonempty_variable_names_; std::vector nonempty_result_names_; std::vector arg_names_; + std::vector variable_names_; std::vector result_names_; // The backing data for the program shape. The proto form of program shape is diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index f5d6b5231ac..880cb5939b6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -210,6 +210,9 @@ TEST(XlaJitCompiledCpuFunction, Sum) { EXPECT_EQ(function.LookupResultIndex("x_name"), -1); EXPECT_EQ(function.LookupResultIndex("y_name"), -1); + EXPECT_EQ(0, function.num_variables()); + EXPECT_EQ(function.LookupVariableIndex("x"), -1); + // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); @@ -252,6 +255,14 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { EXPECT_EQ(*static_cast(function.result_data(0)), 100); EXPECT_EQ(*static_cast(function.result_data(1)), 420); + // Check name to index lookups. + EXPECT_TRUE(function.HasNameIndices()); + + EXPECT_EQ(2, function.num_args()); + + EXPECT_EQ(1, function.num_variables()); + EXPECT_EQ(function.LookupVariableIndex("myvar"), 1); + // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});