[TF/XLA] Fixup numbering of XLA parameters used for aliasing
Previously, the XLA argument parameter number was incorrectly assumed to correspond to the index into the vector of `XlaCompiler::Argument`s. This is not correct, since not all `XlaCompiler::Argument`s become arguments to the compiled XLA computation: notably, constants and uninitialized resource variables do not.

PiperOrigin-RevId: 321709603
Change-Id: I730fd6385949c360b2b831318a5b59c08f8362ef
parent 4f7f17e469
commit b440bbb40f
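The core of the fix is a single inversion: `input_mapping` records, for each XLA parameter, the index of the `XlaCompiler::Argument` it was built from, so inverting it yields the true XLA parameter number for a given argument index. Below is a minimal standalone sketch of that inversion, with hypothetical mapping values and `std::unordered_map` standing in for the `absl::flat_hash_map` used in the patch:

```cpp
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical compilation: arguments 0 and 2 were compile-time constants
  // and were folded away, so only arguments 1 and 3 became XLA parameters.
  // As in the patch, input_mapping[xla_param] holds the originating
  // XlaCompiler::Argument index.
  const std::vector<int> input_mapping = {1, 3};

  // Invert the mapping, mirroring the argument_to_xla_arg loop in the diff.
  std::unordered_map<int, int> argument_to_xla_arg;
  for (int xla_arg = 0; xla_arg < static_cast<int>(input_mapping.size());
       ++xla_arg) {
    argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
  }

  // Argument 3 is XLA parameter 1. Before this fix, the aliasing code would
  // have used the raw argument index (3) as the parameter number.
  std::cout << argument_to_xla_arg[3] << "\n";  // prints 1

  // Argument 0 never became a parameter, so it must not be aliased; this is
  // what the new argument_to_xla_arg.count(param_num) guard checks.
  std::cout << argument_to_xla_arg.count(0) << "\n";  // prints 0
  return 0;
}
```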
@@ -168,7 +168,7 @@ Status BuildComputation(
     int* num_computation_outputs, int* num_nonconst_outputs,
     std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
-    xla::Shape* output_shape) {
+    xla::Shape* output_shape, absl::Span<int const> input_mapping) {
   // Attach a common operator name as metadata. This has no semantic effect — it
   // merely makes the HLO graph more readable when visualized via TensorBoard,
   // since TensorBoard forms groups out of operators with similar names.
@@ -268,6 +268,11 @@ Status BuildComputation(
                 return a->arg_num() < b->arg_num();
               });

+  absl::flat_hash_map<int, int> argument_to_xla_arg;
+  for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
+    argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
+  }
+
   std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
   for (const XlaResource* resource : arg_resources) {
     DCHECK_LT(resource->arg_num(), args.size());
@@ -290,19 +295,20 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
+      int param_num = use_tuple_arg ? 0 : update.input_index;
       if (is_entry_computation &&
           arg.resource_kind != XlaResource::kTensorArray &&
-          alias_resource_update) {
+          alias_resource_update && argument_to_xla_arg.count(param_num)) {
         // Assuming tuple arg and results are used.
         xla::ShapeIndex param_index =
             use_tuple_arg ? xla::ShapeIndex({update.input_index})
                           : xla::ShapeIndex{};
-        int param_number = use_tuple_arg ? 0 : update.input_index;
+        int xla_param_num = argument_to_xla_arg[param_num];
         int64 output_index_num = elems.size();
         xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
         VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
-                << param_number << ", " << param_index.ToString() << ")";
-        aliases.push_back({output_index, param_number, param_index});
+                << xla_param_num << ", " << param_index.ToString() << ")";
+        aliases.push_back({output_index, xla_param_num, param_index});
       }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
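Two addressing modes meet in this hunk: with `use_tuple_arg`, every argument travels inside a single tuple parameter, so the alias is addressed by `param_index` (`ShapeIndex({update.input_index})`) within that tuple; otherwise each argument is its own parameter and `param_index` stays empty. The new `argument_to_xla_arg.count(param_num)` guard additionally skips resources whose argument never became an XLA parameter at all, rather than emitting an alias to a parameter that does not exist.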
@@ -1315,7 +1321,8 @@ Status XlaCompiler::CompileGraph(
       options.always_return_tuple, options.use_tuple_arg,
       options.alias_resource_update, &builder, result->computation.get(),
       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
-      &result->resource_updates, &result->xla_output_shape));
+      &result->resource_updates, &result->xla_output_shape,
+      result->input_mapping));

   VLOG(2) << "Outputs: total: " << context->retvals().size()
           << " nonconstant: " << num_nonconst_outputs;
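`result->input_mapping` is populated earlier in compilation from the arguments that actually became XLA parameters, so the call site only needs to thread it through to `BuildComputation`.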
@@ -1856,5 +1856,46 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }

+TEST_F(XlaCompilerTest, AliasResourceUpdates) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::Const<int32>(scope.WithOpName("A"), {1, 2});
+  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+  auto write = ops::AssignAddVariableOp(scope, var, a);
+  auto read = ops::ReadVariableOp(
+      scope.WithControlDependencies(std::vector<Operation>{write}), var,
+      DT_INT32);
+  auto d = ops::_Retval(scope.WithOpName("D"), read, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kConstant;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2});
+  args[0].constant_value = Tensor(DT_INT32, {1, 1});
+  args[0].initialized = true;
+
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2});
+
+  XlaCompiler compiler(DefaultOptions());
+
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.alias_resource_update = true;
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+                                     args, &result));
+
+  const xla::HloInputOutputAliasProto& alias =
+      result.computation->proto().input_output_alias();
+  EXPECT_EQ(alias.entries_size(), 1);
+  EXPECT_EQ(alias.entries(0).parameter_number(), 0);
+}
+
 }  // namespace
 }  // namespace tensorflow
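The new test exercises exactly the miscounting described in the commit message: argument 0 is a compile-time constant and never becomes an XLA parameter, so the resource variable at argument index 1 ends up as XLA parameter 0. The final assertion would have failed before the fix, which would have recorded the raw argument index 1 as the alias's parameter number.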