Automatically set up user aliasing in tf2xla when a resource update is presented.
- When a resource update is presented, automatically alias the input and output.
- Also fix an issue where the input/output proto config is overwritten.

PiperOrigin-RevId: 294984983
Change-Id: I45e96513dfeaa91f523db63837355b698bd2fb85
commit 70d8aa322c (parent 469e56eeff)
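The gist of the change, as a sketch assembled from the ConvertGraphToXla hunk below (setup and error handling elided): rather than building an explicit alias list and threading it through CompileGraph, callers now set a single compile option and let the compiler alias each updated resource itself.

    std::vector<XlaCompiler::Argument> xla_args;
    TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
    PopulateXlaArgs(config, &xla_args);  // no longer emits an alias list

    XlaCompiler::CompileOptions options;
    options.alias_resource_update = true;  // honored on the TPU backend
    XlaCompiler::CompilationResult result;
    TF_RETURN_IF_ERROR(compiler.CompileGraph(
        options, "tfcompile", std::move(graph), xla_args, &result));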
@@ -242,13 +242,10 @@ Status CreateXlaArgs(const Graph& graph,
   return Status::OK();
 }
 
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases) {
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args) {
   // Populate arguments with resource variables from the config. The variables
   // get turned into inputs and outputs.
-  int64 input_num = xla_args->size();
-  int64 output_num = config.fetch_size();
   for (const tf2xla::Variable& variable : config.variable()) {
     XlaCompiler::Argument arg;
     arg.type = variable.type();
@@ -258,17 +255,9 @@ void PopulateXlaArgsAndXlaAlias(
     arg.resource_kind = XlaResource::kVariable;
     arg.initialized = true;
     xla_args->push_back(std::move(arg));
-
-    if (!variable.readonly()) {
-      // We want to alias the input and output of the variable, so the updates
-      // are carried out in-place.
-      xla_aliases->push_back({/*output_index=*/{output_num},
-                              /*param_number=*/input_num, /*param_index=*/{}});
-      ++output_num;
-    }
-    ++input_num;
   }
 }
 
 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                  std::unique_ptr<Graph>* graph) {
   TF_RETURN_IF_ERROR(ValidateConfig(config));
@@ -29,10 +29,9 @@ namespace tensorflow {
 Status CreateXlaArgs(const Graph& graph,
                      std::vector<XlaCompiler::Argument>* xla_args);
 
-// Populate xla_args and xla_aliases for the given XLA config.
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases);
+// Populate xla_args for the given XLA config.
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args);
 
 // InitGraph creates a graph based on the graph_def, that may then be converted
 // to an xla::XlaComputation via ConvertGraphToXla.
@@ -63,9 +63,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   std::vector<XlaCompiler::Argument> xla_args;
   TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
 
-  std::vector<xla::XlaBuilder::InputOutputAlias> xla_aliases;
-  PopulateXlaArgsAndXlaAlias(config, &xla_args, &xla_aliases);
+  PopulateXlaArgs(config, &xla_args);
 
   // Compile the graph into an XLA computation.
   XlaCompiler::Options compiler_options;
   compiler_options.client = client;
@@ -75,12 +73,15 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   compiler_options.allow_cpu_custom_calls = true;
   compiler_options.custom_fake_quant_op_calls =
       config.conversion_options().custom_fake_quant_op_calls();
 
   XlaCompiler compiler(compiler_options);
 
   XlaCompiler::CompilationResult result;
-  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                           "tfcompile", std::move(graph),
-                                           xla_args, xla_aliases, &result));
+  XlaCompiler::CompileOptions options;
+  options.alias_resource_update = true;
+  TF_RETURN_IF_ERROR(compiler.CompileGraph(
+      options, "tfcompile", std::move(graph), xla_args, &result));
   *computation = std::move(*result.computation);
 
   int num_const_results = 0;
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
@@ -162,9 +163,9 @@ Status BuildComputation(
     std::unique_ptr<xla::XlaOp> token_output,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     bool is_entry_computation, bool return_updated_values_for_all_resources,
-    bool always_return_tuple, xla::XlaBuilder* builder,
-    xla::XlaComputation* computation, int* num_computation_outputs,
-    int* num_nonconst_outputs,
+    bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
+    xla::XlaBuilder* builder, xla::XlaComputation* computation,
+    int* num_computation_outputs, int* num_nonconst_outputs,
     std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
     xla::Shape* output_shape) {
@@ -284,6 +285,7 @@ Status BuildComputation(
           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
           arg.tensor_array_gradients.count(grad.first) == 0;
     }
 
     if (return_updated_values_for_all_resources || modified) {
       resource_updates->emplace_back();
       XlaCompiler::ResourceUpdate& update = resource_updates->back();
@@ -291,6 +293,21 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
+      if (is_entry_computation && always_return_tuple &&
+          arg.resource_kind != XlaResource::kTensorArray &&
+          alias_resource_update) {
+        // Assuming tuple arg and results are used.
+        int64 output_index = elems.size();
+        if (use_tuple_arg) {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/0,
+                              /*param_index=*/{update.input_index});
+        } else {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/update.input_index,
+                              /*param_index=*/{});
+        }
+      }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
       }
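The aliasing above takes one of two forms depending on how the entry computation receives its arguments. As an illustration (the indices here are hypothetical: one updated resource with update.input_index == 2 whose new value lands at tuple output index 5):

    // With use_tuple_arg, all inputs are packed into parameter 0 and the
    // resource sits at tuple index {2} of that parameter.
    builder->SetUpAlias(/*output_index=*/{5}, /*param_number=*/0,
                        /*param_index=*/{2});
    // With flat arguments, the resource is parameter 2 itself.
    builder->SetUpAlias(/*output_index=*/{5}, /*param_number=*/2,
                        /*param_index=*/{});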
@@ -750,7 +767,7 @@ Status XlaCompiler::CompileFunction(
 
   VLOG(1) << "====================================================";
   TF_RETURN_IF_ERROR(
-      CompileGraph(options, function_id, std::move(graph), args, {}, result));
+      CompileGraph(options, function_id, std::move(graph), args, result));
   VLOG(1) << "====================================================";
 
   cache_[{function_id, arg_vector}] = *result;
@@ -1192,8 +1209,7 @@ Status XlaCompiler::CompileSingleOp(
   }
   FixupSourceAndSinkEdges(graph.get());
 
-  return CompileGraph(options, node_def.name(), std::move(graph), args, {},
-                      result);
+  return CompileGraph(options, node_def.name(), std::move(graph), args, result);
 }
 
 namespace {
@@ -1291,7 +1307,6 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
 Status XlaCompiler::CompileGraph(
     const XlaCompiler::CompileOptions& options, string const& name,
     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
     CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
 
@@ -1344,12 +1359,6 @@ Status XlaCompiler::CompileGraph(
       &result->xla_input_shapes, options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
 
-  // Propagate any aliases given to us by the user.
-  for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
-    builder.SetUpAlias(alias.output_index, alias.param_number,
-                       alias.param_index);
-  }
-
   PushNodeTokenMapping();
   // Use std::set instead of std::unordered_set to ensure determinism.
   std::set<std::string> output_node_token_inputs;
@@ -1402,7 +1411,8 @@ Status XlaCompiler::CompileGraph(
                        : ShapeRepresentationFn{},
       options.is_entry_computation,
       options.return_updated_values_for_all_resources,
-      options.always_return_tuple, &builder, result->computation.get(),
+      options.always_return_tuple, options.use_tuple_arg,
+      options.alias_resource_update, &builder, result->computation.get(),
       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
       &result->resource_updates, &result->xla_output_shape));
 
@@ -213,6 +213,12 @@ class XlaCompiler {
 
     // True when we should add XLA input & output to the graph/function.
     bool add_token_input_output = false;
 
+    // Resource updates are converted into input / output of xla. The two
+    // buffers are aliased with each other if this option is true.
+    //
+    // Currently only supports TPU.
+    bool alias_resource_update = false;
   };
 
   struct OutputDescription {
@@ -367,7 +373,6 @@ class XlaCompiler {
   Status CompileGraph(
       const CompileOptions& options, string const& name,
       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
-      absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
      CompilationResult* result);
 
   // Compiles a single Op, given by `node_def`, into an
@@ -183,9 +183,9 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
 
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(
-      XlaCompiler::CompileOptions(), "add", std::move(graph),
-      /*args=*/{}, /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph),
+                                     /*args=*/{}, &result));
 
   TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
 }
@@ -215,8 +215,7 @@ TEST_F(XlaCompilerTest, Simple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -267,7 +266,7 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -319,8 +318,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly tranposed.
@@ -366,8 +364,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_EQ(fast_mem_arg_count, 1);
 }
 
@@ -414,8 +411,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly tranposed.
@@ -456,8 +452,7 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
   // Check that the return shapes are correctly tranposed.
@@ -507,7 +502,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
 }
@@ -537,9 +532,9 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
   XlaCompiler compiler(DefaultOptions());
 
   XlaCompiler::CompilationResult result;
-  Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                        "reshape", std::move(graph), args,
-                                        /*user_aliases=*/{}, &result);
+  Status status =
+      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
+                            std::move(graph), args, &result);
   EXPECT_FALSE(status.ok());
   EXPECT_TRUE(
       absl::StrContains(status.error_message(), "depends on a parameter"))
@@ -581,8 +576,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   XlaCompiler::CompileOptions compile_options;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                     std::move(graph_copy), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph_copy), args, &result));
 
   ASSERT_EQ(2, result.outputs.size());
   EXPECT_FALSE(result.outputs[0].is_constant);
@@ -667,8 +661,7 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
   XlaCompiler::CompileOptions compile_options;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(2, result.outputs.size());
   EXPECT_FALSE(result.outputs[1].is_constant);
@@ -707,8 +700,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, resource->Get());
 
@@ -744,8 +736,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
   XlaCompiler compiler(options);
 
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &results[i]));
+                                     std::move(graph), args, &results[i]));
   }
 
   for (int64 i = 1; i < test_count; ++i) {
@@ -811,8 +802,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(1, result.resource_updates.size());
   const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
@@ -871,8 +861,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(0, result.resource_updates.size());
 }
@@ -904,8 +893,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, result.resource_updates.size());
 }
@@ -980,8 +968,7 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 }
 
 // Tests CompileFunction with a local function lookup failing, fails with
@@ -1064,8 +1051,7 @@ TEST_F(XlaCompilerTest, Variables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1101,7 +1087,7 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   auto compile_options = XlaCompiler::CompileOptions();
   compile_options.always_return_tuple = false;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
@@ -1138,8 +1124,7 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape result_shape =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
 
@@ -1169,8 +1154,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
@@ -1220,8 +1204,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1273,7 +1256,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1344,7 +1327,7 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1425,8 +1408,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
   std::vector<XlaCompiler::Argument> args;
   XlaCompiler::CompilationResult result;
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
       << status.error_message();
@@ -1451,8 +1433,7 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
   XlaCompiler::CompilationResult result;
   XlaCompiler compiler(DefaultOptions());
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(),
                                 "is not in the list of allowed values"))
@@ -1478,8 +1459,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
     CopyGraph(*graph, graph_copy.get());
     XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
-                                       std::move(graph_copy), args,
-                                       /*user_aliases=*/{}, &result));
+                                       std::move(graph_copy), args, &result));
   }
 }
 
@@ -1530,7 +1510,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                       args, /*user_aliases=*/{}, &result));
+                                       args, &result));
    EXPECT_EQ(result.xla_input_shapes.size(), 1);
    EXPECT_TRUE(result.xla_output_shape.IsTuple());
    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
@@ -1548,7 +1528,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                       args, /*user_aliases=*/{}, &result));
+                                       args, &result));
    EXPECT_EQ(result.xla_input_shapes.size(), 2);
    EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
    EXPECT_TRUE(result.xla_output_shape.IsTuple());
@@ -1620,8 +1600,7 @@ TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 2);
   const XlaCompiler::OutputDescription& output0 = result.outputs[0];
   ASSERT_TRUE(output0.is_tensor_list);
@@ -1710,8 +1689,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
   compile_options.return_updated_values_for_all_resources = true;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 3);
   const XlaCompiler::OutputDescription& output1 = result.outputs[1];
   ASSERT_EQ(output1.input_index, 1);
@@ -1772,8 +1750,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that we set sharding on the root TUPLE instruction.
   const auto& hlo_module_proto = result.computation->proto();
@@ -1829,8 +1806,8 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
 
   XlaCompiler::CompilationResult result;
   auto options = XlaCompiler::CompileOptions();
-  TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(
+      compiler.CompileGraph(options, "test", std::move(graph), args, &result));
 
   xla::Literal literal0 =
       xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
@@ -332,6 +332,12 @@ class XlaBuilder {
   // Adds a new input/output alias. Since the input/output shape information are
   // not available until the computation is built, and eventual error in the
   // arguments of this API will be detected only at computation Build() time.
+  //
+  // Note: Aliasing API is 'may-alias' and only donated buffer at runtime will
+  // be aliased with output. If a buffer is not donated at runtime, a copy will
+  // be inserted by XLA to prevent buffer clobbering.
+  //
+  // Only works on TPU backend.
   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                   const ShapeIndex& param_index) {
     input_output_aliases_.push_back({output_index, param_number, param_index});
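Because the alias is 'may-alias', clients can set it up unconditionally; XLA copies whenever the runtime buffer is not actually donated. A small builder-side sketch, mirroring the XRT test added at the end of this diff (a two-element S64 tuple parameter is assumed):

    xla::Shape elem = xla::ShapeUtil::MakeShape(xla::S64, {2});
    xla::Shape shape = xla::ShapeUtil::MakeTupleShape({elem, elem});
    xla::XlaBuilder builder("ReuseBuffer");
    auto param = xla::Parameter(&builder, 0, shape, "param");
    auto p0 = xla::GetTupleElement(param, 0);
    auto p1 = xla::GetTupleElement(param, 1);
    xla::Tuple(&builder, {xla::Add(p0, p1), xla::Sub(p0, p1)});
    // Output tuple element {0} may reuse the donated buffer backing
    // parameter 0's tuple element {0}.
    builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                       /*param_index=*/{0});
    auto computation = builder.Build().ValueOrDie();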
@@ -335,7 +335,7 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
 
   if (parent() && parent()->has_entry_computation() &&
       parent()->entry_computation() == this) {
-    if (!Shape::Equal()(new_root_instruction->shape(),
+    if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
                         root_instruction_->shape())) {
       // Rebuild input output alias config now that we have a new output shape.
       parent()->input_output_alias_config() =
@@ -64,7 +64,8 @@ class HloInputOutputAliasConfig {
   // Sets up alias config from `output_index` to `param_index` at
   // `param_number`.
   Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
-                    const ShapeIndex& param_index, AliasKind kind);
+                    const ShapeIndex& param_index,
+                    AliasKind kind = AliasKind::kUserAlias);
 
   // Returns the kind of alias for the given parameter number and parameter
   // index. If no alias exists, AliasKind::kNoAlias is returned.
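With the defaulted kind, call sites that want a plain user alias can now omit the last argument, as the buffer donation test updated later in this diff does:

    // kind defaults to AliasKind::kUserAlias after this change.
    TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
    TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));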
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -1780,8 +1780,12 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
   }
 
   TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
-      *module, [this](const Shape& shape) {
+      *module, [this](const Shape& shape) -> int64 {
+        if (target_metadata_->IsLayoutSensitive()) {
           return target_metadata_->ShapeSize(shape);
+        } else {
+          return 0;
+        }
       }));
 
   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
@@ -214,6 +214,8 @@ class TargetVerifierMetadata {
 
   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
 
+  virtual bool IsLayoutSensitive() const = 0;
+
   TargetVerifierMetadata() {}
   virtual ~TargetVerifierMetadata() {}
 
|
|||||||
layout_sensitive_, allow_mixed_precision_, shape_size_function_);
|
layout_sensitive_, allow_mixed_precision_, shape_size_function_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsLayoutSensitive() const override { return layout_sensitive_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool layout_sensitive_;
|
bool layout_sensitive_;
|
||||||
bool allow_mixed_precision_;
|
bool allow_mixed_precision_;
|
||||||
|
@@ -215,8 +215,12 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
   auto gte1 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
   builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
 
   module->AddEntryComputation(builder.Build());
+  // Input output aliasing is only supported on TPU.
+#if defined(XLA_TEST_BACKEND_TPU)
+  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
+  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
+#endif
 
   auto arg = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});
@@ -1477,6 +1477,105 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
 }
 
+TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
+  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
+  xla::Shape shape =
+      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
+  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
+      {element_shape, element_shape, element_shape, element_shape});
+  xla::XlaBuilder builder("ReuseBuffer");
+  auto param = xla::Parameter(&builder, 0, shape, "param");
+  auto p0 = xla::GetTupleElement(param, 0);
+  auto p1 = xla::GetTupleElement(param, 1);
+  auto add = xla::Add(p0, p1);
+  auto sub = xla::Sub(p0, p1);
+  xla::Tuple(&builder, {add, sub, p0, p1});
+
+  // Flip the tuple literals in the input handle.
+  builder.SetUpAlias({1}, 0, {0});
+  builder.SetUpAlias({0}, 0, {1});
+
+  auto computation = builder.Build().ValueOrDie();
+
+  auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
+  auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
+  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
+
+  xrt::XLAAllocation param_alloc;
+  *param_alloc.mutable_value() = literal.ToProto();
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = shape.ToProto();
+  *shapes->mutable_result() = return_shape.ToProto();
+  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(false);
+  e.set_release_compilation_handle(true);
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  XrtClientSession session(root);
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto c_data =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, c_data);
+  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                                param_alloc.SerializeAsString());
+  auto param_handle = ops::XRTAllocate(root, param_value);
+  TF_ASSERT_OK(root.status());
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({param_handle}, &outputs));
+
+  int64 alloc_handle = outputs[0].scalar<int64>()();
+
+  // Note that we release the result handle immediately, but since we aliased
+  // the output buffers onto the input allocation ones (held in alloc_handle),
+  // we can fetch the result from there.
+  auto result =
+      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
+  auto read_back = ops::XRTReadLiteral(root, result);
+  auto release = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_back), result);
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(
+      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
+
+  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
+  auto exec_literal_parts = exec_literal.DecomposeTuple();
+  ASSERT_EQ(exec_literal_parts.size(), 4);
+
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
+
+  // Now we read back the original input handle values, which at this point
+  // should contain the result of the XLA computation.
+  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+  auto release_handle = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_handle), Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
+                           {release_handle}, &outputs));
+
+  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
+
+  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
+  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
+  // The first element of the computation returned tuple would be the add
+  // (expected_literal0), but since we flipped the buffers, the sub
+  // (expected_literal1) should come first.
+  auto expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
+
+  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
+}
+
 TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
   xrt::XLAAllocation p0;
   *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();