Add named size and count methods for arg, result and var methods to AOT models.
PiperOrigin-RevId: 310375046 Change-Id: I3fd5c7fbdcfe141449a0a4d6827f6e5fe14b4e0b
This commit is contained in:
parent
0cc3e612bd
commit
96f4a930db
tensorflow/compiler/aot
@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
||||
std::vector<string> dim_vars;
|
||||
string dim_sizes, indices;
|
||||
int count = 1;
|
||||
if (shape.rank() == 0 ||
|
||||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
||||
dim_sizes = "[1]";
|
||||
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
dim_vars.push_back(absl::StrCat("size_t dim", dim));
|
||||
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
|
||||
indices += absl::StrCat("[dim", dim, "]");
|
||||
count *= shape.dimensions(dim);
|
||||
}
|
||||
}
|
||||
rewrites->push_back({"{{I}}", absl::StrCat(i)});
|
||||
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||
rewrites->push_back({"{{INDICES}}", indices});
|
||||
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int arg{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int arg{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.feed(i).name().empty()) {
|
||||
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
result_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int result{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int result{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.fetch(i).name().empty()) {
|
||||
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int var_{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int var_{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
|
@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg0_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg0_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg_myfeed_data(const void* data) {
|
||||
set_arg_data(0, data);
|
||||
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg_myfeed_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg_myfeed_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg1_data(const void* data) {
|
||||
set_arg_data(1, data);
|
||||
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
||||
arg_data(1)))[dim0][dim1];
|
||||
}
|
||||
int arg1_size() const {
|
||||
return 12 * sizeof(tensorflow::int64);
|
||||
}
|
||||
int arg1_count() const {
|
||||
return 12;
|
||||
}
|
||||
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
// Must only be called after a successful Run call. There is a set of methods
|
||||
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result0_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result0_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
tensorflow::uint32* result_myfetch_data() {
|
||||
return static_cast<tensorflow::uint32*>(result_data(0));
|
||||
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result_myfetch_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result_myfetch_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||
//
|
||||
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(2)))[0];
|
||||
}
|
||||
int var_myvar_readonly_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_readonly_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar_data(float* data) {
|
||||
set_arg_data(3, data);
|
||||
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(3)))[0];
|
||||
}
|
||||
int var_myvar_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
set_arg_data(4, data);
|
||||
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int32(*)[5]>(
|
||||
arg_data(4)))[dim0];
|
||||
}
|
||||
int var_myvar2_size() const {
|
||||
return 5 * sizeof(tensorflow::int32);
|
||||
}
|
||||
int var_myvar2_count() const {
|
||||
return 5;
|
||||
}
|
||||
|
||||
private:
|
||||
// Number of buffers for the compiled computation.
|
||||
|
Loading…
Reference in New Issue
Block a user