[XLA:AOT] Add a readonly flag to allow resource variables that are never assigned
It's a bit unfortunate that we have to leave this decision to the user, but I don't see a better way of handling variables that are never assigned to. XLA will just prune away the output for them so we have to account for them early. The API follows the in-place API, but we make the setter const and don't generate getters. PiperOrigin-RevId: 240107733
This commit is contained in:
parent
ec6f4fa16c
commit
26aa3d515d
@ -213,7 +213,11 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
return errors::Internal("codegen requires the XLA result to be a tuple");
|
||||
}
|
||||
size_t num_results = ps.result().tuple_shapes_size();
|
||||
if (config.fetch_size() + config.variable_size() != num_results) {
|
||||
int readonly_variables = absl::c_count_if(
|
||||
config.variable(),
|
||||
[](const tf2xla::Variable& var) { return var.readonly(); });
|
||||
if (config.fetch_size() + config.variable_size() - readonly_variables !=
|
||||
num_results) {
|
||||
return errors::InvalidArgument("mismatch between fetch_size(",
|
||||
config.fetch_size(), ")+variable_size(",
|
||||
config.variable_size(), ") and tuple_size(",
|
||||
@ -256,15 +260,17 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
|
||||
const string code = R"(
|
||||
void set_var_{{NAME}}_data({{TYPE}}* data) {
|
||||
void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
|
||||
set_arg_data({{I}}, data);
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
*methods += RewriteWithName(
|
||||
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
|
||||
}
|
||||
size_t num_results = ps.result().tuple_shapes_size();
|
||||
int variable_num = -1;
|
||||
for (int i = config.fetch_size(); i < num_results; ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(
|
||||
@ -285,7 +291,10 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
result_data({{I}}))){{INDICES}};
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.fetch_size());
|
||||
do {
|
||||
++variable_num;
|
||||
} while (config.variable(variable_num).readonly());
|
||||
const tf2xla::Variable& var = config.variable(variable_num);
|
||||
*methods += RewriteWithName(
|
||||
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
|
||||
}
|
||||
|
@ -175,14 +175,19 @@ TEST(CodegenTest, Golden) {
|
||||
fetch->mutable_id()->set_node_name("fetch0");
|
||||
fetch->set_name("myfetch");
|
||||
tf2xla::Variable* variable = config.add_variable();
|
||||
variable->set_node_name("myvar");
|
||||
variable->set_node_name("myvar_readonly");
|
||||
variable->mutable_shape()->add_dim()->set_size(1);
|
||||
variable->set_type(DT_FLOAT);
|
||||
variable->set_readonly(true);
|
||||
tf2xla::Variable* variable2 = config.add_variable();
|
||||
variable2->set_node_name("my/var");
|
||||
variable2->set_name("myvar2");
|
||||
variable2->mutable_shape()->add_dim()->set_size(5);
|
||||
variable2->set_type(DT_INT32);
|
||||
variable2->set_node_name("myvar");
|
||||
variable2->mutable_shape()->add_dim()->set_size(1);
|
||||
variable2->set_type(DT_FLOAT);
|
||||
tf2xla::Variable* variable3 = config.add_variable();
|
||||
variable3->set_node_name("my/var");
|
||||
variable3->set_name("myvar2");
|
||||
variable3->mutable_shape()->add_dim()->set_size(5);
|
||||
variable3->set_type(DT_INT32);
|
||||
CompileResult compile_result;
|
||||
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
||||
{},
|
||||
@ -198,6 +203,7 @@ TEST(CodegenTest, Golden) {
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
|
||||
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {1}),
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {1}),
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {5}),
|
||||
},
|
||||
xla::ShapeUtil::MakeTupleShape({
|
||||
|
@ -52,7 +52,7 @@ namespace bar {
|
||||
// is guaranteed that no thread may call a non-const method.
|
||||
//
|
||||
// The logical function signature is:
|
||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
|
||||
//
|
||||
// Memory stats:
|
||||
// arg bytes total: 104
|
||||
@ -228,14 +228,18 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
// with dim indices specifying which value. No bounds checking is performed
|
||||
// on dim indices.
|
||||
|
||||
void set_var_myvar_data(float* data) {
|
||||
void set_var_myvar_readonly_data(const float* data) {
|
||||
set_arg_data(2, data);
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
void set_var_myvar_data(float* data) {
|
||||
set_arg_data(3, data);
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
set_arg_data(4, data);
|
||||
}
|
||||
|
||||
float* var_myvar_data() {
|
||||
return static_cast<float*>(result_data(1));
|
||||
}
|
||||
@ -309,7 +313,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
static const xla::ProgramShapeProto* StaticProgramShape() {
|
||||
static const xla::ProgramShapeProto* kShape = []() {
|
||||
xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
|
||||
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132);
|
||||
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 149);
|
||||
return proto;
|
||||
}();
|
||||
return kShape;
|
||||
|
Binary file not shown.
@ -159,10 +159,11 @@ def tfvariable(_):
|
||||
|
||||
def tfvariable_sequential_updates(_):
|
||||
x = variables.Variable(1.0, name='x')
|
||||
y = variables.Variable(1.0, name='y')
|
||||
updates = control_flow_ops.no_op()
|
||||
for _ in range(3):
|
||||
with ops.control_dependencies([updates]):
|
||||
x_val = x.read_value() + 1.0
|
||||
x_val = x.read_value() + y
|
||||
updates = x.assign_sub(0.1 * x_val)
|
||||
|
||||
array_ops.identity(updates, name='result')
|
||||
|
@ -7,3 +7,9 @@ variable {
|
||||
node_name: "x"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
|
||||
variable {
|
||||
node_name: "y"
|
||||
type: DT_FLOAT
|
||||
readonly: true
|
||||
}
|
||||
|
@ -502,20 +502,22 @@ TEST(TFCompileTest, VariableSequentialUpdates) {
|
||||
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
|
||||
|
||||
// This implements the recursion:
|
||||
// x[0] = 1.0
|
||||
// x[0] = 2.0
|
||||
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
|
||||
VariableSequentialUpdatesComp fn;
|
||||
float x = 1;
|
||||
float x = 2;
|
||||
float y = 1;
|
||||
fn.set_var_x_data(&x);
|
||||
fn.set_var_y_data(&y);
|
||||
|
||||
fn.set_thread_pool(&device);
|
||||
// First calculate x[3]
|
||||
fn.Run();
|
||||
EXPECT_NEAR(x, 0.458f, 1e-6);
|
||||
EXPECT_NEAR(x, 1.187f, 1e-6);
|
||||
|
||||
// Then calculate x[6]
|
||||
fn.Run();
|
||||
EXPECT_NEAR(x, 0.062882f, 1e-6);
|
||||
EXPECT_NEAR(x, 0.594322f, 1e-6);
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, AssertEqAndReturnDiff) {
|
||||
|
@ -278,12 +278,14 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
|
||||
arg.initialized = true;
|
||||
xla_args.push_back(std::move(arg));
|
||||
|
||||
// We want to alias the input and output of the variable, so the updates are
|
||||
// carried out in-place.
|
||||
xla_aliases.push_back({/*output_index=*/{output_num},
|
||||
/*param_number=*/input_num, /*param_index=*/{}});
|
||||
if (!variable.readonly()) {
|
||||
// We want to alias the input and output of the variable, so the updates
|
||||
// are carried out in-place.
|
||||
xla_aliases.push_back({/*output_index=*/{output_num},
|
||||
/*param_number=*/input_num, /*param_index=*/{}});
|
||||
++output_num;
|
||||
}
|
||||
++input_num;
|
||||
++output_num;
|
||||
}
|
||||
|
||||
// Compile the graph into an XLA computation.
|
||||
@ -324,6 +326,24 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
|
||||
" constant results. The configuration of "
|
||||
"the output args (i.e. fetch ids) is probably wrong.");
|
||||
}
|
||||
{
|
||||
// Verify that the readonly bits on variables are set correctly by the user.
|
||||
std::vector<bool> updated_inputs(xla_args.size());
|
||||
for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
|
||||
updated_inputs[update.input_index] = true;
|
||||
}
|
||||
int64 input_index = xla_args.size() - config.variable_size();
|
||||
for (const tf2xla::Variable& variable : config.variable()) {
|
||||
if (variable.readonly() == updated_inputs[input_index]) {
|
||||
return errors::InvalidArgument(
|
||||
"Variable \"", variable.node_name(), "\" is marked as ",
|
||||
variable.readonly() ? "" : "not ", "readonly, but is ",
|
||||
updated_inputs[input_index] ? "" : "not ",
|
||||
"modified by the computation.");
|
||||
}
|
||||
++input_index;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.tf2xla;
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "Tf2XlaProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.tf2xla";
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
|
||||
// index of a particular node in the graph. If the output of the named node
|
||||
// feeds into other node(s), this corresponds to one or more edges. Otherwise
|
||||
@ -16,7 +17,7 @@ import "tensorflow/core/framework/types.proto";
|
||||
message TensorId {
|
||||
string node_name = 1;
|
||||
int64 output_index = 2;
|
||||
};
|
||||
}
|
||||
|
||||
// Feed represents a single feed tensor in the graph, which corresponds to an
|
||||
// input argument for the generated computation.
|
||||
@ -30,14 +31,14 @@ message Feed {
|
||||
// not linked into the binary, then the type cannot be inferred from the node;
|
||||
// in this case, the type should be set here.
|
||||
DataType type = 4;
|
||||
};
|
||||
}
|
||||
|
||||
// Fetch represents a single fetch tensor in the graph, which corresponds to an
|
||||
// output argument for the generated computation.
|
||||
message Fetch {
|
||||
TensorId id = 1;
|
||||
string name = 2; // Optional name for generated code.
|
||||
};
|
||||
}
|
||||
|
||||
// Variable represents a resource variable with the given name, shape and type.
|
||||
message Variable {
|
||||
@ -46,6 +47,10 @@ message Variable {
|
||||
2; // Optional name for generated code. If empty, node_name will be used.
|
||||
TensorShapeProto shape = 3;
|
||||
DataType type = 4;
|
||||
|
||||
// Flag for variables that are never assigned. Assigments to a read-only
|
||||
// variable or unassigned variables that are not read-only are invalid.
|
||||
bool readonly = 5;
|
||||
}
|
||||
|
||||
// Config represents configuration information for tf2xla conversion.
|
||||
@ -58,4 +63,4 @@ message Config {
|
||||
repeated Fetch fetch = 2;
|
||||
// Each variable is a named input and output of the generated computation.
|
||||
repeated Variable variable = 3;
|
||||
};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user