Add named size and count methods for arg, result and var methods to AOT models.
PiperOrigin-RevId: 310375046 Change-Id: I3fd5c7fbdcfe141449a0a4d6827f6e5fe14b4e0b
This commit is contained in:
parent
0cc3e612bd
commit
96f4a930db
@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
|||||||
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
||||||
std::vector<string> dim_vars;
|
std::vector<string> dim_vars;
|
||||||
string dim_sizes, indices;
|
string dim_sizes, indices;
|
||||||
|
int count = 1;
|
||||||
if (shape.rank() == 0 ||
|
if (shape.rank() == 0 ||
|
||||||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
||||||
dim_sizes = "[1]";
|
dim_sizes = "[1]";
|
||||||
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
|||||||
dim_vars.push_back(absl::StrCat("size_t dim", dim));
|
dim_vars.push_back(absl::StrCat("size_t dim", dim));
|
||||||
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
|
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
|
||||||
indices += absl::StrCat("[dim", dim, "]");
|
indices += absl::StrCat("[dim", dim, "]");
|
||||||
|
count *= shape.dimensions(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rewrites->push_back({"{{I}}", absl::StrCat(i)});
|
rewrites->push_back({"{{I}}", absl::StrCat(i)});
|
||||||
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
|||||||
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
||||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||||
rewrites->push_back({"{{INDICES}}", indices});
|
rewrites->push_back({"{{INDICES}}", indices});
|
||||||
|
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
|
|||||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||||
arg_data({{I}}))){{INDICES}};
|
arg_data({{I}}))){{INDICES}};
|
||||||
}
|
}
|
||||||
|
int arg{{NAME}}_size() const {
|
||||||
|
return {{COUNT}} * sizeof({{TYPE}});
|
||||||
|
}
|
||||||
|
int arg{{NAME}}_count() const {
|
||||||
|
return {{COUNT}};
|
||||||
|
}
|
||||||
)";
|
)";
|
||||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||||
if (!config.feed(i).name().empty()) {
|
if (!config.feed(i).name().empty()) {
|
||||||
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
|
|||||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||||
result_data({{I}}))){{INDICES}};
|
result_data({{I}}))){{INDICES}};
|
||||||
}
|
}
|
||||||
|
int result{{NAME}}_size() const {
|
||||||
|
return {{COUNT}} * sizeof({{TYPE}});
|
||||||
|
}
|
||||||
|
int result{{NAME}}_count() const {
|
||||||
|
return {{COUNT}};
|
||||||
|
}
|
||||||
)";
|
)";
|
||||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||||
if (!config.fetch(i).name().empty()) {
|
if (!config.fetch(i).name().empty()) {
|
||||||
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
|||||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||||
arg_data({{I}}))){{INDICES}};
|
arg_data({{I}}))){{INDICES}};
|
||||||
}
|
}
|
||||||
|
int var_{{NAME}}_size() const {
|
||||||
|
return {{COUNT}} * sizeof({{TYPE}});
|
||||||
|
}
|
||||||
|
int var_{{NAME}}_count() const {
|
||||||
|
return {{COUNT}};
|
||||||
|
}
|
||||||
)";
|
)";
|
||||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||||
|
@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const float(*)[1][2]>(
|
return (*static_cast<const float(*)[1][2]>(
|
||||||
arg_data(0)))[dim0][dim1];
|
arg_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
int arg0_size() const {
|
||||||
|
return 2 * sizeof(float);
|
||||||
|
}
|
||||||
|
int arg0_count() const {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
void set_arg_myfeed_data(const void* data) {
|
void set_arg_myfeed_data(const void* data) {
|
||||||
set_arg_data(0, data);
|
set_arg_data(0, data);
|
||||||
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const float(*)[1][2]>(
|
return (*static_cast<const float(*)[1][2]>(
|
||||||
arg_data(0)))[dim0][dim1];
|
arg_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
int arg_myfeed_size() const {
|
||||||
|
return 2 * sizeof(float);
|
||||||
|
}
|
||||||
|
int arg_myfeed_count() const {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
void set_arg1_data(const void* data) {
|
void set_arg1_data(const void* data) {
|
||||||
set_arg_data(1, data);
|
set_arg_data(1, data);
|
||||||
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
||||||
arg_data(1)))[dim0][dim1];
|
arg_data(1)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
int arg1_size() const {
|
||||||
|
return 12 * sizeof(tensorflow::int64);
|
||||||
|
}
|
||||||
|
int arg1_count() const {
|
||||||
|
return 12;
|
||||||
|
}
|
||||||
|
|
||||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||||
// Must only be called after a successful Run call. There is a set of methods
|
// Must only be called after a successful Run call. There is a set of methods
|
||||||
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||||
result_data(0)))[dim0][dim1];
|
result_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
int result0_size() const {
|
||||||
|
return 30 * sizeof(tensorflow::uint32);
|
||||||
|
}
|
||||||
|
int result0_count() const {
|
||||||
|
return 30;
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::uint32* result_myfetch_data() {
|
tensorflow::uint32* result_myfetch_data() {
|
||||||
return static_cast<tensorflow::uint32*>(result_data(0));
|
return static_cast<tensorflow::uint32*>(result_data(0));
|
||||||
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||||
result_data(0)))[dim0][dim1];
|
result_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
int result_myfetch_size() const {
|
||||||
|
return 30 * sizeof(tensorflow::uint32);
|
||||||
|
}
|
||||||
|
int result_myfetch_count() const {
|
||||||
|
return 30;
|
||||||
|
}
|
||||||
|
|
||||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||||
//
|
//
|
||||||
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const float(*)[1]>(
|
return (*static_cast<const float(*)[1]>(
|
||||||
arg_data(2)))[0];
|
arg_data(2)))[0];
|
||||||
}
|
}
|
||||||
|
int var_myvar_readonly_size() const {
|
||||||
|
return 1 * sizeof(float);
|
||||||
|
}
|
||||||
|
int var_myvar_readonly_count() const {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
void set_var_myvar_data(float* data) {
|
void set_var_myvar_data(float* data) {
|
||||||
set_arg_data(3, data);
|
set_arg_data(3, data);
|
||||||
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const float(*)[1]>(
|
return (*static_cast<const float(*)[1]>(
|
||||||
arg_data(3)))[0];
|
arg_data(3)))[0];
|
||||||
}
|
}
|
||||||
|
int var_myvar_size() const {
|
||||||
|
return 1 * sizeof(float);
|
||||||
|
}
|
||||||
|
int var_myvar_count() const {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||||
set_arg_data(4, data);
|
set_arg_data(4, data);
|
||||||
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
return (*static_cast<const tensorflow::int32(*)[5]>(
|
return (*static_cast<const tensorflow::int32(*)[5]>(
|
||||||
arg_data(4)))[dim0];
|
arg_data(4)))[dim0];
|
||||||
}
|
}
|
||||||
|
int var_myvar2_size() const {
|
||||||
|
return 5 * sizeof(tensorflow::int32);
|
||||||
|
}
|
||||||
|
int var_myvar2_count() const {
|
||||||
|
return 5;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Number of buffers for the compiled computation.
|
// Number of buffers for the compiled computation.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user