Add named size and count methods for arg, result and var methods to AOT models.

PiperOrigin-RevId: 310375046
Change-Id: I3fd5c7fbdcfe141449a0a4d6827f6e5fe14b4e0b
This commit is contained in:
A. Unique TensorFlower 2020-05-07 09:24:28 -07:00 committed by TensorFlower Gardener
parent 0cc3e612bd
commit 96f4a930db
2 changed files with 69 additions and 0 deletions
tensorflow/compiler/aot

View File

@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
std::vector<string> dim_vars;
string dim_sizes, indices;
int count = 1;
if (shape.rank() == 0 ||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
dim_sizes = "[1]";
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
dim_vars.push_back(absl::StrCat("size_t dim", dim));
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
indices += absl::StrCat("[dim", dim, "]");
count *= shape.dimensions(dim);
}
}
rewrites->push_back({"{{I}}", absl::StrCat(i)});
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
return Status::OK();
}
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int arg{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int arg{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.feed(i).name().empty()) {
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
}
int result{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int result{{NAME}}_count() const {
return {{COUNT}};
}
)";
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.fetch(i).name().empty()) {
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
int var_{{NAME}}_size() const {
return {{COUNT}} * sizeof({{TYPE}});
}
int var_{{NAME}}_count() const {
return {{COUNT}};
}
)";
const tf2xla::Variable& var = config.variable(i - config.feed_size());
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");

View File

@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg0_size() const {
return 2 * sizeof(float);
}
int arg0_count() const {
return 2;
}
void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data);
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1][2]>(
arg_data(0)))[dim0][dim1];
}
int arg_myfeed_size() const {
return 2 * sizeof(float);
}
int arg_myfeed_count() const {
return 2;
}
void set_arg1_data(const void* data) {
set_arg_data(1, data);
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int64(*)[3][4]>(
arg_data(1)))[dim0][dim1];
}
int arg1_size() const {
return 12 * sizeof(tensorflow::int64);
}
int arg1_count() const {
return 12;
}
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result0_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result0_count() const {
return 30;
}
tensorflow::uint32* result_myfetch_data() {
return static_cast<tensorflow::uint32*>(result_data(0));
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
result_data(0)))[dim0][dim1];
}
int result_myfetch_size() const {
return 30 * sizeof(tensorflow::uint32);
}
int result_myfetch_count() const {
return 30;
}
// Methods for managing variable buffers. Buffers are in row-major order.
//
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
int var_myvar_readonly_size() const {
return 1 * sizeof(float);
}
int var_myvar_readonly_count() const {
return 1;
}
void set_var_myvar_data(float* data) {
set_arg_data(3, data);
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const float(*)[1]>(
arg_data(3)))[0];
}
int var_myvar_size() const {
return 1 * sizeof(float);
}
int var_myvar_count() const {
return 1;
}
void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(4, data);
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return (*static_cast<const tensorflow::int32(*)[5]>(
arg_data(4)))[dim0];
}
int var_myvar2_size() const {
return 5 * sizeof(tensorflow::int32);
}
int var_myvar2_count() const {
return 5;
}
private:
// Number of buffers for the compiled computation.