[XLA:AOT] Add a readonly flag to allow resource variables that are never assigned
It's a bit unfortunate that we have to leave this decision to the user, but I don't see a better way of handling variables that are never assigned to. XLA will simply prune away the output for them, so we have to account for them early. The API follows the in-place API, but we make the setter take a const pointer and don't generate getters.

PiperOrigin-RevId: 240107733
parent ec6f4fa16c
commit 26aa3d515d
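As a quick illustration of the API decision described above (the authoritative generated methods appear in the golden-header diff further down): a readonly variable now gets a setter that accepts a const pointer and no getter, while a writable variable keeps the in-place setter/getter pair. The names and argument/result indices in this sketch are hypothetical:

// Sketch only: "v_ro" is a hypothetical readonly variable at arg index 2,
// "v_rw" a hypothetical writable variable at arg index 3 with result index 1.
void set_var_v_ro_data(const float* data) {
  set_arg_data(2, data);  // readonly: const data accepted, no getter emitted
}
void set_var_v_rw_data(float* data) {
  set_arg_data(3, data);  // writable: aliased in place with its output
}
float* var_v_rw_data() {
  return static_cast<float*>(result_data(1));  // updated value after Run()
}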
@@ -213,7 +213,11 @@ Status GenResultMethods(const tf2xla::Config& config,
     return errors::Internal("codegen requires the XLA result to be a tuple");
   }
   size_t num_results = ps.result().tuple_shapes_size();
-  if (config.fetch_size() + config.variable_size() != num_results) {
+  int readonly_variables = absl::c_count_if(
+      config.variable(),
+      [](const tf2xla::Variable& var) { return var.readonly(); });
+  if (config.fetch_size() + config.variable_size() - readonly_variables !=
+      num_results) {
     return errors::InvalidArgument("mismatch between fetch_size(",
                                    config.fetch_size(), ")+variable_size(",
                                    config.variable_size(), ") and tuple_size(",
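A quick worked check of the new arithmetic, using the numbers from the updated golden test below: the config there has 1 fetch and 3 variables, one of which is readonly, so the expected tuple size is

  fetch_size + variable_size - readonly_variables = 1 + 3 - 1 = 3,

which matches the three-element result tuple (u32[5,6], f32[1], s32[5]) in the regenerated golden header.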
@@ -256,15 +260,17 @@ Status GenVariableMethods(const tf2xla::Config& config,
     TF_RETURN_IF_ERROR(
         AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
     const string code = R"(
-  void set_var_{{NAME}}_data({{TYPE}}* data) {
+  void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
     set_arg_data({{I}}, data);
   }
 )";
     const tf2xla::Variable& var = config.variable(i - config.feed_size());
+    rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
     *methods += RewriteWithName(
         var.name().empty() ? var.node_name() : var.name(), code, rewrites);
   }
   size_t num_results = ps.result().tuple_shapes_size();
+  int variable_num = -1;
   for (int i = config.fetch_size(); i < num_results; ++i) {
     std::vector<std::pair<string, string>> rewrites;
     TF_RETURN_IF_ERROR(AddRewritesForShape(
@@ -285,7 +291,10 @@ Status GenVariableMethods(const tf2xla::Config& config,
         result_data({{I}}))){{INDICES}};
   }
 )";
-    const tf2xla::Variable& var = config.variable(i - config.fetch_size());
+    do {
+      ++variable_num;
+    } while (config.variable(variable_num).readonly());
+    const tf2xla::Variable& var = config.variable(variable_num);
     *methods += RewriteWithName(
         var.name().empty() ? var.node_name() : var.name(), code, rewrites);
   }
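Because readonly variables produce no XLA output, the getter-generation loop above can no longer assume that result i maps to variable i - fetch_size(); the do/while advances variable_num past every readonly variable so each result tuple element pairs with the next variable that is actually written. A minimal self-contained sketch of the same index-mapping idea, with a hypothetical bitmask standing in for config.variable():

#include <vector>

// Returns the index of the variable backing the k-th written result, where
// readonly[i] marks variables that produce no output. Hypothetical helper,
// for illustration only.
int VariableForResult(const std::vector<bool>& readonly, int k) {
  int variable_num = -1;
  for (int seen = -1; seen < k;) {
    ++variable_num;
    if (!readonly[variable_num]) ++seen;
  }
  return variable_num;
}
// e.g. with readonly = {true, false, false} (the golden test's config),
// VariableForResult(readonly, 0) == 1 and VariableForResult(readonly, 1) == 2.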
@@ -175,14 +175,19 @@ TEST(CodegenTest, Golden) {
   fetch->mutable_id()->set_node_name("fetch0");
   fetch->set_name("myfetch");
   tf2xla::Variable* variable = config.add_variable();
-  variable->set_node_name("myvar");
+  variable->set_node_name("myvar_readonly");
   variable->mutable_shape()->add_dim()->set_size(1);
   variable->set_type(DT_FLOAT);
+  variable->set_readonly(true);
   tf2xla::Variable* variable2 = config.add_variable();
-  variable2->set_node_name("my/var");
-  variable2->set_name("myvar2");
-  variable2->mutable_shape()->add_dim()->set_size(5);
-  variable2->set_type(DT_INT32);
+  variable2->set_node_name("myvar");
+  variable2->mutable_shape()->add_dim()->set_size(1);
+  variable2->set_type(DT_FLOAT);
+  tf2xla::Variable* variable3 = config.add_variable();
+  variable3->set_node_name("my/var");
+  variable3->set_name("myvar2");
+  variable3->mutable_shape()->add_dim()->set_size(5);
+  variable3->set_type(DT_INT32);
   CompileResult compile_result;
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
@@ -198,6 +203,7 @@ TEST(CodegenTest, Golden) {
           xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
           xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
           xla::ShapeUtil::MakeShape(xla::F32, {1}),
+          xla::ShapeUtil::MakeShape(xla::F32, {1}),
           xla::ShapeUtil::MakeShape(xla::S32, {5}),
       },
       xla::ShapeUtil::MakeTupleShape({
@@ -52,7 +52,7 @@ namespace bar {
 // is guaranteed that no thread may call a non-const method.
 //
 // The logical function signature is:
-//   ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
+//   ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
 //
 // Memory stats:
 //   arg bytes total: 104
@@ -228,14 +228,18 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
   // with dim indices specifying which value. No bounds checking is performed
   // on dim indices.
 
-  void set_var_myvar_data(float* data) {
+  void set_var_myvar_readonly_data(const float* data) {
     set_arg_data(2, data);
   }
 
-  void set_var_myvar2_data(tensorflow::int32* data) {
+  void set_var_myvar_data(float* data) {
     set_arg_data(3, data);
   }
 
+  void set_var_myvar2_data(tensorflow::int32* data) {
+    set_arg_data(4, data);
+  }
+
   float* var_myvar_data() {
     return static_cast<float*>(result_data(1));
   }
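A short usage sketch against the regenerated header above (buffer values and the include are made up for illustration; the feed arguments are omitted for brevity, a real caller would set those too): the readonly variable can be fed from immutable storage, while writable variables keep the in-place read-modify-write contract.

#include "myclass.h"  // hypothetical include for the generated foo::bar::MyClass

int main() {
  static const float kReadonlyInit = 42.0f;  // may live in read-only storage
  float myvar = 1.0f;
  tensorflow::int32 myvar2[5] = {0, 1, 2, 3, 4};

  foo::bar::MyClass computation;
  computation.set_var_myvar_readonly_data(&kReadonlyInit);  // const float* accepted
  computation.set_var_myvar_data(&myvar);
  computation.set_var_myvar2_data(myvar2);
  computation.Run();
  // Only the writable variables have getters; the readonly output was pruned.
  float updated = *computation.var_myvar_data();
  (void)updated;
  return 0;
}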
@@ -309,7 +313,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
   static const xla::ProgramShapeProto* StaticProgramShape() {
     static const xla::ProgramShapeProto* kShape = []() {
       xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
-      proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132);
+      proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 149);
       return proto;
     }();
     return kShape;
Binary file not shown.
@@ -159,10 +159,11 @@ def tfvariable(_):
 
 def tfvariable_sequential_updates(_):
   x = variables.Variable(1.0, name='x')
+  y = variables.Variable(1.0, name='y')
   updates = control_flow_ops.no_op()
   for _ in range(3):
     with ops.control_dependencies([updates]):
-      x_val = x.read_value() + 1.0
+      x_val = x.read_value() + y
       updates = x.assign_sub(0.1 * x_val)
 
   array_ops.identity(updates, name='result')
@@ -7,3 +7,9 @@ variable {
   node_name: "x"
   type: DT_FLOAT
 }
+
+variable {
+  node_name: "y"
+  type: DT_FLOAT
+  readonly: true
+}
@@ -502,20 +502,22 @@ TEST(TFCompileTest, VariableSequentialUpdates) {
   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
 
   // This implements the recursion:
-  // x[0] = 1.0
+  // x[0] = 2.0
   // x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
   VariableSequentialUpdatesComp fn;
-  float x = 1;
+  float x = 2;
+  float y = 1;
   fn.set_var_x_data(&x);
+  fn.set_var_y_data(&y);
 
   fn.set_thread_pool(&device);
   // First calculate x[3]
   fn.Run();
-  EXPECT_NEAR(x, 0.458f, 1e-6);
+  EXPECT_NEAR(x, 1.187f, 1e-6);
 
   // Then calculate x[6]
   fn.Run();
-  EXPECT_NEAR(x, 0.062882f, 1e-6);
+  EXPECT_NEAR(x, 0.594322f, 1e-6);
 }
 
 TEST(TFCompileTest, AssertEqAndReturnDiff) {
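For reference, the expectations above can be reproduced by hand. With y held constant at 1.0, the loop body computes x[n+1] = x[n] - 0.1*(x[n] + y) (the x[n-1] in the source comment looks like a pre-existing typo), so starting from x[0] = 2:

  x[1] = 2.0  - 0.1*(2.0  + 1.0) = 1.7
  x[2] = 1.7  - 0.1*(1.7  + 1.0) = 1.43
  x[3] = 1.43 - 0.1*(1.43 + 1.0) = 1.187

and three more iterations give x[6] ≈ 0.594322, matching both EXPECT_NEAR calls.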
@@ -278,12 +278,14 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
     arg.initialized = true;
     xla_args.push_back(std::move(arg));
 
-    // We want to alias the input and output of the variable, so the updates are
-    // carried out in-place.
-    xla_aliases.push_back({/*output_index=*/{output_num},
-                           /*param_number=*/input_num, /*param_index=*/{}});
+    if (!variable.readonly()) {
+      // We want to alias the input and output of the variable, so the updates
+      // are carried out in-place.
+      xla_aliases.push_back({/*output_index=*/{output_num},
+                             /*param_number=*/input_num, /*param_index=*/{}});
+      ++output_num;
+    }
     ++input_num;
-    ++output_num;
   }
 
   // Compile the graph into an XLA computation.
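The hunk above is the core behavioral change in ConvertGraphToXla: a readonly variable still consumes an input argument (++input_num runs unconditionally) but no longer claims an output position or an in-place alias, so ++output_num moves inside the if. A condensed, self-contained sketch of the resulting bookkeeping, with simplified stand-in types for the real XLA structures:

#include <vector>

// Stand-ins for tf2xla::Variable and the XLA alias entries; illustrative only.
struct Variable { bool readonly; };
struct Alias { int output_index; int param_number; };

// Mirrors the bookkeeping above: every variable consumes an input slot, but
// only writable variables claim an output slot and an in-place alias.
std::vector<Alias> AssignAliases(const std::vector<Variable>& vars,
                                 int input_num, int output_num) {
  std::vector<Alias> aliases;
  for (const Variable& v : vars) {
    if (!v.readonly) {
      aliases.push_back({output_num, input_num});
      ++output_num;  // an output exists only when the variable is written
    }
    ++input_num;  // readonly or not, the variable is still an input argument
  }
  return aliases;
}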
@@ -324,6 +326,24 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
                                    " constant results. The configuration of "
                                    "the output args (i.e. fetch ids) is probably wrong.");
   }
+  {
+    // Verify that the readonly bits on variables are set correctly by the user.
+    std::vector<bool> updated_inputs(xla_args.size());
+    for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
+      updated_inputs[update.input_index] = true;
+    }
+    int64 input_index = xla_args.size() - config.variable_size();
+    for (const tf2xla::Variable& variable : config.variable()) {
+      if (variable.readonly() == updated_inputs[input_index]) {
+        return errors::InvalidArgument(
+            "Variable \"", variable.node_name(), "\" is marked as ",
+            variable.readonly() ? "" : "not ", "readonly, but is ",
+            updated_inputs[input_index] ? "" : "not ",
+            "modified by the computation.");
+      }
+      ++input_index;
+    }
+  }
   return Status::OK();
 }
 
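The verification block added above makes a misconfigured readonly bit fail fast in both directions. Illustrative messages it would produce (variable names taken from this commit's examples; the wording follows the errors::InvalidArgument call):

  Variable "y" is marked as not readonly, but is not modified by the computation.
  Variable "x" is marked as readonly, but is modified by the computation.

i.e. readonly must be set if and only if the computation never writes the variable.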
@@ -1,14 +1,15 @@
 syntax = "proto3";
 
 package tensorflow.tf2xla;
 
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
 option cc_enable_arenas = true;
 option java_outer_classname = "Tf2XlaProtos";
 option java_multiple_files = true;
 option java_package = "org.tensorflow.tf2xla";
 
-import "tensorflow/core/framework/tensor_shape.proto";
-import "tensorflow/core/framework/types.proto";
-
 // TensorId identifies a tensor in a TensorFlow graph, by specifying the output
 // index of a particular node in the graph. If the output of the named node
 // feeds into other node(s), this corresponds to one or more edges. Otherwise
@@ -16,7 +17,7 @@ import "tensorflow/core/framework/types.proto";
 message TensorId {
   string node_name = 1;
   int64 output_index = 2;
-};
+}
 
 // Feed represents a single feed tensor in the graph, which corresponds to an
 // input argument for the generated computation.
@@ -30,14 +31,14 @@ message Feed {
   // not linked into the binary, then the type cannot be inferred from the node;
   // in this case, the type should be set here.
   DataType type = 4;
-};
+}
 
 // Fetch represents a single fetch tensor in the graph, which corresponds to an
 // output argument for the generated computation.
 message Fetch {
   TensorId id = 1;
   string name = 2;  // Optional name for generated code.
-};
+}
 
 // Variable represents a resource variable with the given name, shape and type.
 message Variable {
@@ -46,6 +47,10 @@ message Variable {
       2;  // Optional name for generated code. If empty, node_name will be used.
   TensorShapeProto shape = 3;
   DataType type = 4;
+
+  // Flag for variables that are never assigned. Assignments to a read-only
+  // variable or unassigned variables that are not read-only are invalid.
+  bool readonly = 5;
 }
 
 // Config represents configuration information for tf2xla conversion.
@@ -58,4 +63,4 @@ message Config {
   repeated Fetch fetch = 2;
   // Each variable is a named input and output of the generated computation.
   repeated Variable variable = 3;
-};
+}