[xla] add better support for variables(num_variables, index lookup by name) in xla aot/jit.
PiperOrigin-RevId: 300243405 Change-Id: Iab455be5b0d3ec594b8482de2e61f5049bc4cb14
This commit is contained in:
parent
6857b1b67b
commit
8e8c97cf6e
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
|
@ -288,8 +289,8 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates code implementing {Arg,Result}Names(), where T is one of
|
// Generates code implementing {Arg,Result}Names(), where T is one of
|
||||||
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
|
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
|
||||||
// literal in the array, with nullptr terminating the array.
|
// string literal in the array, with nullptr terminating the array.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string GenNameToIndexCode(const T& entries, bool generate) {
|
string GenNameToIndexCode(const T& entries, bool generate) {
|
||||||
// No need for a static array if we're not supposed to generate the data.
|
// No need for a static array if we're not supposed to generate the data.
|
||||||
|
@ -419,6 +420,16 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||||
// Generate metadata.
|
// Generate metadata.
|
||||||
const string arg_names_code =
|
const string arg_names_code =
|
||||||
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
|
||||||
|
|
||||||
|
auto variable_copy = config.variable();
|
||||||
|
for (auto& var : variable_copy) {
|
||||||
|
if (var.name().empty()) {
|
||||||
|
var.set_name(var.node_name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const string variable_names_code =
|
||||||
|
GenNameToIndexCode(variable_copy, opts.gen_name_to_index);
|
||||||
|
|
||||||
const string result_names_code =
|
const string result_names_code =
|
||||||
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
|
||||||
const string include_xla_data_proto =
|
const string include_xla_data_proto =
|
||||||
|
@ -507,6 +518,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
// Number of input arguments for the compiled computation.
|
// Number of input arguments for the compiled computation.
|
||||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||||
|
|
||||||
|
// Number of variables for the compiled computation.
|
||||||
|
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
|
||||||
|
|
||||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||||
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
|
||||||
|
@ -522,8 +536,10 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
set_static_data_num_buffers(data, kNumBuffers);
|
set_static_data_num_buffers(data, kNumBuffers);
|
||||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||||
set_static_data_num_args(data, kNumArgs);
|
set_static_data_num_args(data, kNumArgs);
|
||||||
|
set_static_data_num_variables(data, kNumVariables);
|
||||||
set_static_data_result_index(data, kResultIndex);
|
set_static_data_result_index(data, kResultIndex);
|
||||||
set_static_data_arg_names(data, StaticArgNames());
|
set_static_data_arg_names(data, StaticArgNames());
|
||||||
|
set_static_data_variable_names(data, StaticVariableNames());
|
||||||
set_static_data_result_names(data, StaticResultNames());
|
set_static_data_result_names(data, StaticResultNames());
|
||||||
set_static_data_program_shape(data, StaticProgramShape());
|
set_static_data_program_shape(data, StaticProgramShape());
|
||||||
set_static_data_hlo_profile_printer_data(
|
set_static_data_hlo_profile_printer_data(
|
||||||
|
@ -626,6 +642,9 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
// Array of names of each positional argument, terminated by nullptr.
|
// Array of names of each positional argument, terminated by nullptr.
|
||||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||||
|
|
||||||
|
// Array of names of each positional variable, terminated by nullptr.
|
||||||
|
static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}
|
||||||
|
|
||||||
// Array of names of each positional result, terminated by nullptr.
|
// Array of names of each positional result, terminated by nullptr.
|
||||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||||
|
|
||||||
|
@ -654,6 +673,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
||||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||||
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
||||||
|
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
|
||||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||||
{"{{CLASS}}", opts.class_name},
|
{"{{CLASS}}", opts.class_name},
|
||||||
|
@ -673,6 +693,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||||
metadata_result.program_shape_access_shim},
|
metadata_result.program_shape_access_shim},
|
||||||
|
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
|
||||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||||
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
||||||
|
|
|
@ -156,17 +156,14 @@ static void CompareWithGoldenFile(
|
||||||
// bazel test --test_strategy=local \
|
// bazel test --test_strategy=local \
|
||||||
// third_party/tensorflow/compiler/aot:codegen_test
|
// third_party/tensorflow/compiler/aot:codegen_test
|
||||||
const bool update_golden = false;
|
const bool update_golden = false;
|
||||||
string golden_file_name;
|
string golden_file_name =
|
||||||
|
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
||||||
|
|
||||||
if (update_golden) {
|
if (update_golden) {
|
||||||
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
|
|
||||||
tensorflow_relative_golden_file_name);
|
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
|
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
|
||||||
}
|
}
|
||||||
|
|
||||||
golden_file_name =
|
|
||||||
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
|
|
||||||
string golden_file_contents;
|
string golden_file_contents;
|
||||||
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
|
||||||
&golden_file_contents));
|
&golden_file_contents));
|
||||||
|
@ -220,10 +217,16 @@ TEST(CodegenTest, Golden) {
|
||||||
{},
|
{},
|
||||||
{BufferInfo::MakeTempBuffer(1),
|
{BufferInfo::MakeTempBuffer(1),
|
||||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||||
BufferInfo::MakeTempBuffer(2),
|
BufferInfo::MakeTempBuffer(1),
|
||||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||||
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
|
BufferInfo::MakeTempBuffer(1),
|
||||||
5, {}));
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
||||||
|
BufferInfo::MakeTempBuffer(1),
|
||||||
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
||||||
|
BufferInfo::MakeTempBuffer(1),
|
||||||
|
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
||||||
|
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
||||||
|
11, {}));
|
||||||
compile_result.program_shape =
|
compile_result.program_shape =
|
||||||
xla::ShapeUtil::MakeProgramShape(
|
xla::ShapeUtil::MakeProgramShape(
|
||||||
{
|
{
|
||||||
|
|
|
@ -55,14 +55,17 @@ namespace bar {
|
||||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||||
//
|
//
|
||||||
// Memory stats:
|
// Memory stats:
|
||||||
// arg bytes total: 104
|
// arg bytes total: 392
|
||||||
// arg bytes aligned: 192
|
// arg bytes aligned: 576
|
||||||
// temp bytes total: 126
|
// temp bytes total: 126
|
||||||
// temp bytes aligned: 320
|
// temp bytes aligned: 512
|
||||||
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
public:
|
public:
|
||||||
// Number of input arguments for the compiled computation.
|
// Number of input arguments for the compiled computation.
|
||||||
static constexpr size_t kNumArgs = 2;
|
static constexpr size_t kNumArgs = 5;
|
||||||
|
|
||||||
|
// Number of variables for the compiled computation.
|
||||||
|
static constexpr size_t kNumVariables = 3;
|
||||||
|
|
||||||
// Byte size of each argument buffer. There are kNumArgs entries.
|
// Byte size of each argument buffer. There are kNumArgs entries.
|
||||||
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
|
||||||
|
@ -79,8 +82,10 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
set_static_data_num_buffers(data, kNumBuffers);
|
set_static_data_num_buffers(data, kNumBuffers);
|
||||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||||
set_static_data_num_args(data, kNumArgs);
|
set_static_data_num_args(data, kNumArgs);
|
||||||
|
set_static_data_num_variables(data, kNumVariables);
|
||||||
set_static_data_result_index(data, kResultIndex);
|
set_static_data_result_index(data, kResultIndex);
|
||||||
set_static_data_arg_names(data, StaticArgNames());
|
set_static_data_arg_names(data, StaticArgNames());
|
||||||
|
set_static_data_variable_names(data, StaticVariableNames());
|
||||||
set_static_data_result_names(data, StaticResultNames());
|
set_static_data_result_names(data, StaticResultNames());
|
||||||
set_static_data_program_shape(data, StaticProgramShape());
|
set_static_data_program_shape(data, StaticProgramShape());
|
||||||
set_static_data_hlo_profile_printer_data(
|
set_static_data_hlo_profile_printer_data(
|
||||||
|
@ -295,16 +300,22 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Number of buffers for the compiled computation.
|
// Number of buffers for the compiled computation.
|
||||||
static constexpr size_t kNumBuffers = 6;
|
static constexpr size_t kNumBuffers = 12;
|
||||||
|
|
||||||
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||||
static const ::xla::cpu_function_runtime::BufferInfo
|
static const ::xla::cpu_function_runtime::BufferInfo
|
||||||
kBufferInfos[kNumBuffers] = {
|
kBufferInfos[kNumBuffers] = {
|
||||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
|
||||||
|
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||||
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||||
};
|
};
|
||||||
return kBufferInfos;
|
return kBufferInfos;
|
||||||
|
@ -312,13 +323,13 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
|
|
||||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||||
1, 3
|
1, 3, 5, 7, 9
|
||||||
};
|
};
|
||||||
return kArgIndexToBufferIndex;
|
return kArgIndexToBufferIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The 0-based index of the result tuple in the temporary buffers.
|
// The 0-based index of the result tuple in the temporary buffers.
|
||||||
static constexpr size_t kResultIndex = 5;
|
static constexpr size_t kResultIndex = 11;
|
||||||
|
|
||||||
// Array of names of each positional argument, terminated by nullptr.
|
// Array of names of each positional argument, terminated by nullptr.
|
||||||
static const char** StaticArgNames() {
|
static const char** StaticArgNames() {
|
||||||
|
@ -326,6 +337,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||||
return kNames;
|
return kNames;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Array of names of each positional variable, terminated by nullptr.
|
||||||
|
static const char** StaticVariableNames() {
|
||||||
|
static const char* kNames[] = {"myvar_readonly", "myvar", "myvar2", nullptr};
|
||||||
|
return kNames;
|
||||||
|
}
|
||||||
|
|
||||||
// Array of names of each positional result, terminated by nullptr.
|
// Array of names of each positional result, terminated by nullptr.
|
||||||
static const char** StaticResultNames() {
|
static const char** StaticResultNames() {
|
||||||
static const char* kNames[] = {"myfetch", nullptr};
|
static const char* kNames[] = {"myfetch", nullptr};
|
||||||
|
|
|
@ -28,7 +28,9 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
|
||||||
buffer_infos_(static_data.buffer_infos_),
|
buffer_infos_(static_data.buffer_infos_),
|
||||||
arg_index_table_(static_data.arg_index_table_),
|
arg_index_table_(static_data.arg_index_table_),
|
||||||
num_args_(static_data.num_args_),
|
num_args_(static_data.num_args_),
|
||||||
|
num_variables_(static_data.num_variables_),
|
||||||
arg_names_(static_data.arg_names_),
|
arg_names_(static_data.arg_names_),
|
||||||
|
variable_names_(static_data.variable_names_),
|
||||||
result_names_(static_data.result_names_),
|
result_names_(static_data.result_names_),
|
||||||
program_shape_(static_data.program_shape_),
|
program_shape_(static_data.program_shape_),
|
||||||
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
|
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
|
||||||
|
@ -63,6 +65,8 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kNotFound = -1;
|
||||||
|
|
||||||
// Linear search through `names` looking for a match with `name`. Returns -1 if
|
// Linear search through `names` looking for a match with `name`. Returns -1 if
|
||||||
// the name isn't found, or is empty.
|
// the name isn't found, or is empty.
|
||||||
//
|
//
|
||||||
|
@ -72,7 +76,6 @@ int LookupNameIndex(const string& name, const char** names) {
|
||||||
// for AOT try the setting the tfcompile --gen_name_to_index flag.
|
// for AOT try the setting the tfcompile --gen_name_to_index flag.
|
||||||
assert(names != nullptr);
|
assert(names != nullptr);
|
||||||
|
|
||||||
constexpr int kNotFound = -1;
|
|
||||||
if (name.empty()) {
|
if (name.empty()) {
|
||||||
return kNotFound;
|
return kNotFound;
|
||||||
}
|
}
|
||||||
|
@ -90,6 +93,14 @@ int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const {
|
||||||
return LookupNameIndex(name, arg_names_);
|
return LookupNameIndex(name, arg_names_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const {
|
||||||
|
int index = LookupNameIndex(name, variable_names_);
|
||||||
|
if (index == kNotFound) {
|
||||||
|
return kNotFound;
|
||||||
|
}
|
||||||
|
return num_args_ - num_variables_ + index;
|
||||||
|
}
|
||||||
|
|
||||||
int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
|
int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
|
||||||
return LookupNameIndex(name, result_names_);
|
return LookupNameIndex(name, result_names_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,12 +76,16 @@ class XlaCompiledCpuFunction {
|
||||||
// There are num_args entry parameters.
|
// There are num_args entry parameters.
|
||||||
int64 num_args_ = 0;
|
int64 num_args_ = 0;
|
||||||
|
|
||||||
|
// There are num_variables variables.
|
||||||
|
int64 num_variables_ = 0;
|
||||||
|
|
||||||
// The 0-based index of the result tuple, in the temp buffers.
|
// The 0-based index of the result tuple, in the temp buffers.
|
||||||
size_t result_index_ = 0;
|
size_t result_index_ = 0;
|
||||||
|
|
||||||
// [Optional] Arrays of arg and result names. These are arrays of C-style
|
// [Optional] Arrays of arg and result names. These are arrays of C-style
|
||||||
// strings, where the array is terminated by nullptr.
|
// strings, where the array is terminated by nullptr.
|
||||||
const char** arg_names_ = nullptr;
|
const char** arg_names_ = nullptr;
|
||||||
|
const char** variable_names_ = nullptr;
|
||||||
const char** result_names_ = nullptr;
|
const char** result_names_ = nullptr;
|
||||||
|
|
||||||
// [Optional] Arg and result shapes.
|
// [Optional] Arg and result shapes.
|
||||||
|
@ -150,6 +154,8 @@ class XlaCompiledCpuFunction {
|
||||||
|
|
||||||
int num_args() const { return num_args_; }
|
int num_args() const { return num_args_; }
|
||||||
|
|
||||||
|
int num_variables() const { return num_variables_; }
|
||||||
|
|
||||||
// Returns the size of entry parameter `idx`.
|
// Returns the size of entry parameter `idx`.
|
||||||
//
|
//
|
||||||
// There is a static version of this method on tfcompile generated subclasses
|
// There is a static version of this method on tfcompile generated subclasses
|
||||||
|
@ -212,10 +218,11 @@ class XlaCompiledCpuFunction {
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
// Methods for extracting optional metadata.
|
// Methods for extracting optional metadata.
|
||||||
|
|
||||||
// Returns true iff data is available for the Lookup{Arg,Result}Index methods.
|
// Returns true iff data is available for the Lookup{Arg,Variable,Result}Index
|
||||||
// E.g. the data might not be compiled into the binary for AOT.
|
// methods. E.g. the data might not be compiled into the binary for AOT.
|
||||||
bool HasNameIndices() const {
|
bool HasNameIndices() const {
|
||||||
return arg_names_ != nullptr && result_names_ != nullptr;
|
return arg_names_ != nullptr && variable_names_ != nullptr &&
|
||||||
|
result_names_ != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the 0-based index for the argument with the given `name`.
|
// Returns the 0-based index for the argument with the given `name`.
|
||||||
|
@ -226,6 +233,14 @@ class XlaCompiledCpuFunction {
|
||||||
// Recommended usage is to capture this in a variable for re-use.
|
// Recommended usage is to capture this in a variable for re-use.
|
||||||
int LookupArgIndex(const string& name) const;
|
int LookupArgIndex(const string& name) const;
|
||||||
|
|
||||||
|
// Returns the 0-based index for the variable with the given `name`.
|
||||||
|
// Returns -1 if the name wasn't found, or data isn't available.
|
||||||
|
//
|
||||||
|
// The index remains constant for every instance of XlaCompiledCpuFunction
|
||||||
|
// generated from the same static data, and might not be cheap to determine.
|
||||||
|
// Recommended usage is to capture this in a variable for re-use.
|
||||||
|
int LookupVariableIndex(const string& name) const;
|
||||||
|
|
||||||
// Returns the 0-based index for the result with the given `name`.
|
// Returns the 0-based index for the result with the given `name`.
|
||||||
// Returns -1 if the name wasn't found, or data isn't available.
|
// Returns -1 if the name wasn't found, or data isn't available.
|
||||||
//
|
//
|
||||||
|
@ -280,6 +295,11 @@ class XlaCompiledCpuFunction {
|
||||||
static_data->num_args_ = num_args;
|
static_data->num_args_ = num_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void set_static_data_num_variables(StaticData* static_data,
|
||||||
|
int64 num_variables) {
|
||||||
|
static_data->num_variables_ = num_variables;
|
||||||
|
}
|
||||||
|
|
||||||
static void set_static_data_result_index(StaticData* static_data,
|
static void set_static_data_result_index(StaticData* static_data,
|
||||||
size_t result_index) {
|
size_t result_index) {
|
||||||
static_data->result_index_ = result_index;
|
static_data->result_index_ = result_index;
|
||||||
|
@ -290,6 +310,11 @@ class XlaCompiledCpuFunction {
|
||||||
static_data->arg_names_ = arg_names;
|
static_data->arg_names_ = arg_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void set_static_data_variable_names(StaticData* static_data,
|
||||||
|
const char** variable_names) {
|
||||||
|
static_data->variable_names_ = variable_names;
|
||||||
|
}
|
||||||
|
|
||||||
static void set_static_data_result_names(StaticData* static_data,
|
static void set_static_data_result_names(StaticData* static_data,
|
||||||
const char** result_names) {
|
const char** result_names) {
|
||||||
static_data->result_names_ = result_names;
|
static_data->result_names_ = result_names;
|
||||||
|
@ -334,6 +359,9 @@ class XlaCompiledCpuFunction {
|
||||||
// The number of incoming arguments.
|
// The number of incoming arguments.
|
||||||
const int32 num_args_;
|
const int32 num_args_;
|
||||||
|
|
||||||
|
// The number of incoming variables.
|
||||||
|
const int32 num_variables_;
|
||||||
|
|
||||||
// Backing memory for buffer_table_ and args_, the latter depending on
|
// Backing memory for buffer_table_ and args_, the latter depending on
|
||||||
// AllocMode.
|
// AllocMode.
|
||||||
void* alloc_buffer_table_ = nullptr;
|
void* alloc_buffer_table_ = nullptr;
|
||||||
|
@ -346,6 +374,7 @@ class XlaCompiledCpuFunction {
|
||||||
|
|
||||||
// Optional metadata.
|
// Optional metadata.
|
||||||
const char** arg_names_ = nullptr;
|
const char** arg_names_ = nullptr;
|
||||||
|
const char** variable_names_ = nullptr;
|
||||||
const char** result_names_ = nullptr;
|
const char** result_names_ = nullptr;
|
||||||
const xla::ProgramShapeProto* program_shape_ = nullptr;
|
const xla::ProgramShapeProto* program_shape_ = nullptr;
|
||||||
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
|
||||||
|
|
|
@ -49,9 +49,9 @@ xla::StatusOr<size_t> ComputeResultIndex(
|
||||||
return result_slice.index();
|
return result_slice.index();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold
|
// Collect names from `entries`, where T is one of
|
||||||
// the actual strings in nonempty_names, and hold arrays of pointers in
|
// tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names,
|
||||||
// name_ptrs, terminated by a nullptr entry.
|
// and hold arrays of pointers in name_ptrs, terminated by a nullptr entry.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CollectNames(const T& entries, std::vector<string>* nonempty_names,
|
void CollectNames(const T& entries, std::vector<string>* nonempty_names,
|
||||||
std::vector<const char*>* name_ptrs) {
|
std::vector<const char*>* name_ptrs) {
|
||||||
|
@ -154,14 +154,28 @@ XlaJitCompiledCpuFunction::Compile(
|
||||||
&jit->static_data_, jit->arg_index_table_.data());
|
&jit->static_data_, jit->arg_index_table_.data());
|
||||||
XlaCompiledCpuFunction::set_static_data_num_args(
|
XlaCompiledCpuFunction::set_static_data_num_args(
|
||||||
&jit->static_data_, jit->arg_index_table_.size());
|
&jit->static_data_, jit->arg_index_table_.size());
|
||||||
|
XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_,
|
||||||
|
config.variable_size());
|
||||||
XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_,
|
XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_,
|
||||||
result_index);
|
result_index);
|
||||||
// Optional metadata is collected and set below.
|
// Optional metadata is collected and set below.
|
||||||
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
|
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
|
||||||
|
|
||||||
|
auto variable_copy = config.variable();
|
||||||
|
for (auto& var : variable_copy) {
|
||||||
|
if (var.name().empty()) {
|
||||||
|
var.set_name(var.node_name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CollectNames(variable_copy, &jit->nonempty_variable_names_,
|
||||||
|
&jit->variable_names_);
|
||||||
|
|
||||||
CollectNames(config.fetch(), &jit->nonempty_result_names_,
|
CollectNames(config.fetch(), &jit->nonempty_result_names_,
|
||||||
&jit->result_names_);
|
&jit->result_names_);
|
||||||
XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_,
|
XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_,
|
||||||
jit->arg_names_.data());
|
jit->arg_names_.data());
|
||||||
|
XlaCompiledCpuFunction::set_static_data_variable_names(
|
||||||
|
&jit->static_data_, jit->variable_names_.data());
|
||||||
XlaCompiledCpuFunction::set_static_data_result_names(
|
XlaCompiledCpuFunction::set_static_data_result_names(
|
||||||
&jit->static_data_, jit->result_names_.data());
|
&jit->static_data_, jit->result_names_.data());
|
||||||
XlaCompiledCpuFunction::set_static_data_program_shape(
|
XlaCompiledCpuFunction::set_static_data_program_shape(
|
||||||
|
|
|
@ -77,8 +77,10 @@ class XlaJitCompiledCpuFunction {
|
||||||
// nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
|
// nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
|
||||||
// data to refer to.
|
// data to refer to.
|
||||||
std::vector<string> nonempty_arg_names_;
|
std::vector<string> nonempty_arg_names_;
|
||||||
|
std::vector<string> nonempty_variable_names_;
|
||||||
std::vector<string> nonempty_result_names_;
|
std::vector<string> nonempty_result_names_;
|
||||||
std::vector<const char*> arg_names_;
|
std::vector<const char*> arg_names_;
|
||||||
|
std::vector<const char*> variable_names_;
|
||||||
std::vector<const char*> result_names_;
|
std::vector<const char*> result_names_;
|
||||||
|
|
||||||
// The backing data for the program shape. The proto form of program shape is
|
// The backing data for the program shape. The proto form of program shape is
|
||||||
|
|
|
@ -210,6 +210,9 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
|
||||||
EXPECT_EQ(function.LookupResultIndex("x_name"), -1);
|
EXPECT_EQ(function.LookupResultIndex("x_name"), -1);
|
||||||
EXPECT_EQ(function.LookupResultIndex("y_name"), -1);
|
EXPECT_EQ(function.LookupResultIndex("y_name"), -1);
|
||||||
|
|
||||||
|
EXPECT_EQ(0, function.num_variables());
|
||||||
|
EXPECT_EQ(function.LookupVariableIndex("x"), -1);
|
||||||
|
|
||||||
// Check program shape.
|
// Check program shape.
|
||||||
using xla::ShapeUtil;
|
using xla::ShapeUtil;
|
||||||
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
||||||
|
@ -252,6 +255,14 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) {
|
||||||
EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 100);
|
EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 100);
|
||||||
EXPECT_EQ(*static_cast<int32*>(function.result_data(1)), 420);
|
EXPECT_EQ(*static_cast<int32*>(function.result_data(1)), 420);
|
||||||
|
|
||||||
|
// Check name to index lookups.
|
||||||
|
EXPECT_TRUE(function.HasNameIndices());
|
||||||
|
|
||||||
|
EXPECT_EQ(2, function.num_args());
|
||||||
|
|
||||||
|
EXPECT_EQ(1, function.num_variables());
|
||||||
|
EXPECT_EQ(function.LookupVariableIndex("myvar"), 1);
|
||||||
|
|
||||||
// Check program shape.
|
// Check program shape.
|
||||||
using xla::ShapeUtil;
|
using xla::ShapeUtil;
|
||||||
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
||||||
|
|
Loading…
Reference in New Issue