Automatically set up user aliasing in tf2xla when a resource update is presented.
- When a resource update is presented, automatically alias the input and output.
- Also fix an issue where the input/output proto config is overwritten.

PiperOrigin-RevId: 294984983
Change-Id: I45e96513dfeaa91f523db63837355b698bd2fb85
parent 469e56eeff
commit 70d8aa322c
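To orient the reader before the hunks: with this change, a tfcompile-style caller no longer threads a user_aliases list through CompileGraph; it sets the new CompileOptions flag and lets the compiler alias updated resource variables itself. A minimal sketch of the new calling convention, mirroring ConvertGraphToXla in the diff below (`client`, `graph`, and `xla_args` are assumed to come from the usual setup elided here):

// Sketch only: `client`, `graph`, and `xla_args` are assumed to exist.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
XlaCompiler compiler(compiler_options);

XlaCompiler::CompileOptions options;
// New flag: alias each updated resource variable's output buffer to its
// input parameter (may-alias; currently honored only on TPU).
options.alias_resource_update = true;

XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph(
    options, "tfcompile", std::move(graph), xla_args, &result));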
@@ -242,13 +242,10 @@ Status CreateXlaArgs(const Graph& graph,
   return Status::OK();
 }
 
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases) {
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args) {
   // Populate arguments with resource variables from the config. The variables
   // get turned into inputs and outputs.
-  int64 input_num = xla_args->size();
-  int64 output_num = config.fetch_size();
   for (const tf2xla::Variable& variable : config.variable()) {
     XlaCompiler::Argument arg;
     arg.type = variable.type();
@@ -258,17 +255,9 @@ void PopulateXlaArgsAndXlaAlias(
     arg.resource_kind = XlaResource::kVariable;
     arg.initialized = true;
     xla_args->push_back(std::move(arg));
 
-    if (!variable.readonly()) {
-      // We want to alias the input and output of the variable, so the updates
-      // are carried out in-place.
-      xla_aliases->push_back({/*output_index=*/{output_num},
-                              /*param_number=*/input_num, /*param_index=*/{}});
-      ++output_num;
-    }
-    ++input_num;
   }
 }
 
 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                  std::unique_ptr<Graph>* graph) {
   TF_RETURN_IF_ERROR(ValidateConfig(config));
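To make the readonly distinction concrete, here is a hypothetical config feeding the new PopulateXlaArgs; the field names come from tf2xla.proto, while the variable names are invented for illustration. A non-readonly variable is exactly the case that used to require an xla_aliases entry and is now aliased automatically inside the compiler:

// Hypothetical example: one updatable and one read-only variable.
tf2xla::Config config;
tf2xla::Variable* counter = config.add_variable();
counter->set_node_name("counter");  // invented name
counter->set_type(DT_INT32);
// `readonly` defaults to false, so "counter" becomes both an input and an
// output of the computation, and its two buffers get aliased.

tf2xla::Variable* table = config.add_variable();
table->set_node_name("table");  // invented name
table->set_type(DT_FLOAT);
table->set_readonly(true);  // input only; there is no update to alias.

std::vector<XlaCompiler::Argument> xla_args;
PopulateXlaArgs(config, &xla_args);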
@@ -29,10 +29,9 @@ namespace tensorflow {
 Status CreateXlaArgs(const Graph& graph,
                      std::vector<XlaCompiler::Argument>* xla_args);
 
-// Populate xla_args and xla_aliases for the given XLA config.
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases);
+// Populate xla_args for the given XLA config.
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args);
 
 // InitGraph creates a graph based on the graph_def, that may then be converted
 // to an xla::XlaComputation via ConvertGraphToXla.
@@ -63,9 +63,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   std::vector<XlaCompiler::Argument> xla_args;
   TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
 
-  std::vector<xla::XlaBuilder::InputOutputAlias> xla_aliases;
-  PopulateXlaArgsAndXlaAlias(config, &xla_args, &xla_aliases);
-
+  PopulateXlaArgs(config, &xla_args);
   // Compile the graph into an XLA computation.
   XlaCompiler::Options compiler_options;
   compiler_options.client = client;
@@ -75,12 +73,15 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   compiler_options.allow_cpu_custom_calls = true;
   compiler_options.custom_fake_quant_op_calls =
       config.conversion_options().custom_fake_quant_op_calls();
 
   XlaCompiler compiler(compiler_options);
 
   XlaCompiler::CompilationResult result;
-  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                           "tfcompile", std::move(graph),
-                                           xla_args, xla_aliases, &result));
+
+  XlaCompiler::CompileOptions options;
+  options.alias_resource_update = true;
+  TF_RETURN_IF_ERROR(compiler.CompileGraph(
+      options, "tfcompile", std::move(graph), xla_args, &result));
   *computation = std::move(*result.computation);
 
   int num_const_results = 0;
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
@@ -162,9 +163,9 @@ Status BuildComputation(
     std::unique_ptr<xla::XlaOp> token_output,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     bool is_entry_computation, bool return_updated_values_for_all_resources,
-    bool always_return_tuple, xla::XlaBuilder* builder,
-    xla::XlaComputation* computation, int* num_computation_outputs,
-    int* num_nonconst_outputs,
+    bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
+    xla::XlaBuilder* builder, xla::XlaComputation* computation,
+    int* num_computation_outputs, int* num_nonconst_outputs,
     std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
     xla::Shape* output_shape) {
@@ -284,6 +285,7 @@ Status BuildComputation(
           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
           arg.tensor_array_gradients.count(grad.first) == 0;
     }
 
     if (return_updated_values_for_all_resources || modified) {
       resource_updates->emplace_back();
       XlaCompiler::ResourceUpdate& update = resource_updates->back();
@@ -291,6 +293,21 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
+      if (is_entry_computation && always_return_tuple &&
+          arg.resource_kind != XlaResource::kTensorArray &&
+          alias_resource_update) {
+        // Assuming tuple arg and results are used.
+        int64 output_index = elems.size();
+        if (use_tuple_arg) {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/0,
+                              /*param_index=*/{update.input_index});
+        } else {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/update.input_index,
+                              /*param_index=*/{});
+        }
+      }
      for (const auto& grad : resource->tensor_array_gradients()) {
        update.tensor_array_gradients_accessed.insert(grad.first);
      }
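The two SetUpAlias branches above differ only in where the resource parameter lives. A hedged illustration with invented indices, assuming the updated resource was passed as input 1 and its new value is tuple output element 2:

// use_tuple_arg == true: all inputs are packed into parameter 0, so the
// resource is addressed by a tuple index *inside* that parameter.
builder->SetUpAlias(/*output_index=*/{2}, /*param_number=*/0,
                    /*param_index=*/{1});

// use_tuple_arg == false: every input is its own parameter, so the
// resource is parameter 1 itself and param_index stays empty.
builder->SetUpAlias(/*output_index=*/{2}, /*param_number=*/1,
                    /*param_index=*/{});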
@@ -750,7 +767,7 @@ Status XlaCompiler::CompileFunction(
 
   VLOG(1) << "====================================================";
   TF_RETURN_IF_ERROR(
-      CompileGraph(options, function_id, std::move(graph), args, {}, result));
+      CompileGraph(options, function_id, std::move(graph), args, result));
   VLOG(1) << "====================================================";
 
   cache_[{function_id, arg_vector}] = *result;
@@ -1192,8 +1209,7 @@ Status XlaCompiler::CompileSingleOp(
   }
   FixupSourceAndSinkEdges(graph.get());
 
-  return CompileGraph(options, node_def.name(), std::move(graph), args, {},
-                      result);
+  return CompileGraph(options, node_def.name(), std::move(graph), args, result);
 }
 
 namespace {
@@ -1291,7 +1307,6 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
 Status XlaCompiler::CompileGraph(
     const XlaCompiler::CompileOptions& options, string const& name,
     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
     CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
@@ -1344,12 +1359,6 @@ Status XlaCompiler::CompileGraph(
       &result->xla_input_shapes, options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
 
-  // Propagate any aliases given to us by the user.
-  for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
-    builder.SetUpAlias(alias.output_index, alias.param_number,
-                       alias.param_index);
-  }
-
   PushNodeTokenMapping();
   // Use std::set instead of std::unordered_set to ensure determinism.
   std::set<std::string> output_node_token_inputs;
@@ -1402,7 +1411,8 @@ Status XlaCompiler::CompileGraph(
           : ShapeRepresentationFn{},
       options.is_entry_computation,
      options.return_updated_values_for_all_resources,
-      options.always_return_tuple, &builder, result->computation.get(),
+      options.always_return_tuple, options.use_tuple_arg,
+      options.alias_resource_update, &builder, result->computation.get(),
      &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
      &result->resource_updates, &result->xla_output_shape));
@@ -213,6 +213,12 @@ class XlaCompiler {
 
     // True when we should add XLA input & output to the graph/function.
     bool add_token_input_output = false;
+
+    // Resource updates are converted into input / output of the XLA
+    // computation. The two buffers are aliased with each other if this
+    // option is true.
+    //
+    // Currently only supported on TPU.
+    bool alias_resource_update = false;
   };
 
   struct OutputDescription {
@@ -367,7 +373,6 @@ class XlaCompiler {
   Status CompileGraph(
       const CompileOptions& options, string const& name,
       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
-      absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
       CompilationResult* result);
 
   // Compiles a single Op, given by `node_def`, into an
@@ -183,9 +183,9 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
 
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(
-      XlaCompiler::CompileOptions(), "add", std::move(graph),
-      /*args=*/{}, /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph),
+                                     /*args=*/{}, &result));
 
   TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
 }
@@ -215,8 +215,7 @@ TEST_F(XlaCompilerTest, Simple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -267,7 +266,7 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -319,8 +318,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly transposed.
@@ -366,8 +364,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_EQ(fast_mem_arg_count, 1);
 }
 
@@ -414,8 +411,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly transposed.
@@ -456,8 +452,7 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
   // Check that the return shapes are correctly transposed.
@@ -507,7 +502,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
 }
@@ -537,9 +532,9 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
   XlaCompiler compiler(DefaultOptions());
 
   XlaCompiler::CompilationResult result;
-  Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                        "reshape", std::move(graph), args,
-                                        /*user_aliases=*/{}, &result);
+  Status status =
+      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
+                            std::move(graph), args, &result);
   EXPECT_FALSE(status.ok());
   EXPECT_TRUE(
       absl::StrContains(status.error_message(), "depends on a parameter"))
@@ -581,8 +576,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   XlaCompiler::CompileOptions compile_options;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                     std::move(graph_copy), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph_copy), args, &result));
 
   ASSERT_EQ(2, result.outputs.size());
   EXPECT_FALSE(result.outputs[0].is_constant);
@@ -667,8 +661,7 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
   XlaCompiler::CompileOptions compile_options;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(2, result.outputs.size());
   EXPECT_FALSE(result.outputs[1].is_constant);
@@ -707,8 +700,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, resource->Get());
 
@@ -744,8 +736,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
     XlaCompiler compiler(options);
 
     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                       std::move(graph), args,
-                                       /*user_aliases=*/{}, &results[i]));
+                                       std::move(graph), args, &results[i]));
   }
 
   for (int64 i = 1; i < test_count; ++i) {
@@ -811,8 +802,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(1, result.resource_updates.size());
   const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
@@ -871,8 +861,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(0, result.resource_updates.size());
 }
@@ -904,8 +893,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, result.resource_updates.size());
 }
@@ -980,8 +968,7 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 }
 
 // Tests CompileFunction with a local function lookup failing, fails with
@@ -1064,8 +1051,7 @@ TEST_F(XlaCompilerTest, Variables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1101,7 +1087,7 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   auto compile_options = XlaCompiler::CompileOptions();
   compile_options.always_return_tuple = false;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
@@ -1138,8 +1124,7 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape result_shape =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
 
@@ -1169,8 +1154,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
@@ -1220,8 +1204,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1273,7 +1256,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1344,7 +1327,7 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1425,8 +1408,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
   std::vector<XlaCompiler::Argument> args;
   XlaCompiler::CompilationResult result;
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
       << status.error_message();
@@ -1451,8 +1433,7 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
   XlaCompiler::CompilationResult result;
   XlaCompiler compiler(DefaultOptions());
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(),
                                 "is not in the list of allowed values"))
@@ -1478,8 +1459,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
   CopyGraph(*graph, graph_copy.get());
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
-                                     std::move(graph_copy), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph_copy), args, &result));
 }
 }
 
@@ -1530,7 +1510,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
   CopyGraph(*graph, graph_copy.get());
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_EQ(result.xla_input_shapes.size(), 1);
   EXPECT_TRUE(result.xla_output_shape.IsTuple());
   EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
@@ -1548,7 +1528,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
   CopyGraph(*graph, graph_copy.get());
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_EQ(result.xla_input_shapes.size(), 2);
   EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
   EXPECT_TRUE(result.xla_output_shape.IsTuple());
@@ -1620,8 +1600,7 @@ TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 2);
   const XlaCompiler::OutputDescription& output0 = result.outputs[0];
   ASSERT_TRUE(output0.is_tensor_list);
@@ -1710,8 +1689,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
   compile_options.return_updated_values_for_all_resources = true;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 3);
   const XlaCompiler::OutputDescription& output1 = result.outputs[1];
   ASSERT_EQ(output1.input_index, 1);
@@ -1772,8 +1750,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that we set sharding on the root TUPLE instruction.
   const auto& hlo_module_proto = result.computation->proto();
@@ -1829,8 +1806,8 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
 
   XlaCompiler::CompilationResult result;
   auto options = XlaCompiler::CompileOptions();
-  TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(
+      compiler.CompileGraph(options, "test", std::move(graph), args, &result));
 
   xla::Literal literal0 =
       xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
@@ -332,6 +332,12 @@ class XlaBuilder {
   // Adds a new input/output alias. Since the input/output shape information
   // is not available until the computation is built, any error in the
   // arguments of this API will be detected only at computation Build() time.
+  //
+  // Note: the aliasing API is 'may-alias' and only a buffer donated at
+  // runtime will be aliased with the output; if a buffer is not donated,
+  // XLA inserts a copy to prevent buffer clobbering.
+  //
+  // Only works on the TPU backend.
   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                   const ShapeIndex& param_index) {
     input_output_aliases_.push_back({output_index, param_number, param_index});
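A quick standalone illustration of the may-alias contract documented above (a sketch with invented shapes and names, not code from this change):

// Build f(p) = p + p and request that the result reuse parameter 0's buffer.
xla::XlaBuilder builder("alias_example");
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
auto p = xla::Parameter(&builder, 0, shape, "p");
xla::Add(p, p);
// May-alias: honored only if the caller actually donates parameter 0's
// buffer at runtime; otherwise XLA inserts a copy rather than clobbering
// the input.
builder.SetUpAlias(/*output_index=*/{}, /*param_number=*/0,
                   /*param_index=*/{});
xla::XlaComputation computation = builder.Build().ValueOrDie();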
@@ -335,7 +335,7 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
 
   if (parent() && parent()->has_entry_computation() &&
       parent()->entry_computation() == this) {
-    if (!Shape::Equal()(new_root_instruction->shape(),
-                        root_instruction_->shape())) {
+    if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
+                                       root_instruction_->shape())) {
       // Rebuild input output alias config now that we have a new output shape.
       parent()->input_output_alias_config() =
@@ -64,7 +64,8 @@ class HloInputOutputAliasConfig {
   // Sets up alias config from `output_index` to `param_index` at
   // `param_number`.
   Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
-                    const ShapeIndex& param_index, AliasKind kind);
+                    const ShapeIndex& param_index,
+                    AliasKind kind = AliasKind::kUserAlias);
 
   // Returns the kind of alias for the given parameter number and parameter
   // index. If no alias exists, AliasKind::kNoAlias is returned.
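With the default argument in place, call sites that want a plain user alias can drop the kind, as the buffer-donation test further down does; a hedged sketch (the module and indices are placeholders):

// Alias tuple output element {0} to element {0} of tuple parameter 0.
// The AliasKind argument now defaults to AliasKind::kUserAlias.
TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
    /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));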
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -1780,8 +1780,12 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
   }
 
   TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
-      *module, [this](const Shape& shape) {
+      *module, [this](const Shape& shape) -> int64 {
+        if (target_metadata_->IsLayoutSensitive()) {
           return target_metadata_->ShapeSize(shape);
+        } else {
+          return 0;
+        }
       }));
 
   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
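A side note on the explicit -> int64 above: once the lambda gains a second return statement, the deduced return types disagree (int64 from ShapeSize, int from the literal 0) and deduction fails, so the trailing return type is required. A minimal standalone illustration, not code from this change (int64 here is TensorFlow's usual integer alias):

// Ill-formed: the two returns deduce different types (int64 vs. int).
//   auto bad = [](bool b) { if (b) return int64{1}; return 0; };

// OK: the trailing return type converts both returns to int64.
auto ok = [](bool b) -> int64 {
  if (b) return int64{1};
  return 0;
};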
@@ -214,6 +214,8 @@ class TargetVerifierMetadata {
 
   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
 
+  virtual bool IsLayoutSensitive() const = 0;
+
   TargetVerifierMetadata() {}
   virtual ~TargetVerifierMetadata() {}
@@ -245,6 +247,8 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata {
         layout_sensitive_, allow_mixed_precision_, shape_size_function_);
   }
 
+  bool IsLayoutSensitive() const override { return layout_sensitive_; }
+
  private:
   bool layout_sensitive_;
   bool allow_mixed_precision_;
@@ -215,8 +215,12 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
   auto gte1 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
   builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
 
   module->AddEntryComputation(builder.Build());
+  // Input/output aliasing is only supported on TPU.
+#if defined(XLA_TEST_BACKEND_TPU)
   TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
   TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
+#endif
 
   auto arg = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});
@@ -1477,6 +1477,105 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
 }
 
+TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
+  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
+  xla::Shape shape =
+      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
+  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
+      {element_shape, element_shape, element_shape, element_shape});
+  xla::XlaBuilder builder("ReuseBuffer");
+  auto param = xla::Parameter(&builder, 0, shape, "param");
+  auto p0 = xla::GetTupleElement(param, 0);
+  auto p1 = xla::GetTupleElement(param, 1);
+  auto add = xla::Add(p0, p1);
+  auto sub = xla::Sub(p0, p1);
+  xla::Tuple(&builder, {add, sub, p0, p1});
+
+  // Flip the tuple literals in the input handle.
+  builder.SetUpAlias({1}, 0, {0});
+  builder.SetUpAlias({0}, 0, {1});
+
+  auto computation = builder.Build().ValueOrDie();
+
+  auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
+  auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
+  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
+
+  xrt::XLAAllocation param_alloc;
+  *param_alloc.mutable_value() = literal.ToProto();
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = shape.ToProto();
+  *shapes->mutable_result() = return_shape.ToProto();
+  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(false);
+  e.set_release_compilation_handle(true);
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  XrtClientSession session(root);
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto c_data =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, c_data);
+  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                                param_alloc.SerializeAsString());
+  auto param_handle = ops::XRTAllocate(root, param_value);
+  TF_ASSERT_OK(root.status());
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({param_handle}, &outputs));
+
+  int64 alloc_handle = outputs[0].scalar<int64>()();
+
+  // Note that we release the result handle immediately, but since we aliased
+  // the output buffers onto the input allocation ones (held in alloc_handle),
+  // we can fetch the result from there.
+  auto result =
+      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
+  auto read_back = ops::XRTReadLiteral(root, result);
+  auto release = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_back), result);
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(
+      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
+
+  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
+  auto exec_literal_parts = exec_literal.DecomposeTuple();
+  ASSERT_EQ(exec_literal_parts.size(), 4);
+
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
+
+  // Now we read back the original input handle values, which at this point
+  // should contain the result of the XLA computation.
+  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+  auto release_handle = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_handle), Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
+                           {release_handle}, &outputs));
+
+  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
+
+  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
+  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
+  // The first element of the computation's returned tuple would be the add
+  // (expected_literal0), but since we flipped the buffers, the sub
+  // (expected_literal1) should come first.
+  auto expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
+
+  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
+}
+
 TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
   xrt::XLAAllocation p0;
   *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();