STT-tensorflow/tensorflow/compiler/tf2xla/xla_compiler_test.cc
Ken Franko a4219770e9 Fix build and reenable xla_compiler_test.
Remove const type from vectors.

PiperOrigin-RevId: 328157979
Change-Id: I8df58b0b23831b842c04c3243290ca61ecf7f4aa
2020-08-24 10:03:56 -07:00

1960 lines
79 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
// Test fixture shared by the XlaCompiler tests in this file. It lives outside
// the anonymous namespace and reads the private member
// XlaCompiler::local_flib_def_ in LocalFlibDef(), presumably because
// XlaCompiler declares this class a friend — NOTE(review): confirm in
// xla_compiler.h before renaming or moving it.
class XlaCompilerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    client_ = xla::ClientLibrary::LocalClientOrDie();
    XlaOpRegistry::RegisterCompilationKernels();
    // Every test starts from an empty function library over the global op
    // registry.
    FunctionDefLibrary flib;
    flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
  }

  // Returns compiler options that target the CPU XLA JIT device and use the
  // fixture's local client and function library.
  XlaCompiler::Options DefaultOptions() {
    XlaCompiler::Options options;
    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
    options.client = client_;
    options.flib_def = flib_def_.get();
    return options;
  }

  // Returns the compiler's private local function library for inspection.
  FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
    return compiler->local_flib_def_.get();
  }

  // Assigned in SetUp(). Initialized to null so that any use before SetUp()
  // is a deterministic null dereference instead of reading an indeterminate
  // pointer.
  xla::Client* client_ = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
namespace {
// Helper resource used to test the ability to pass resources through to XLA
// compiled kernels. It holds a single integer counter that starts at zero.
class DummyResourceForTest : public ResourceBase {
 public:
  string DebugString() const override { return "dummy"; }

  // Adds one to the counter.
  void Increment() { ++value_; }

  // Returns the current counter value. Marked const because it only reads
  // state, so it can also be called through a const pointer or reference.
  int Get() const { return value_; }

 private:
  int value_ = 0;
};
class DummyReadResourceOp : public XlaOpKernel {
public:
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
DummyResourceForTest* dummy;
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
rm->default_container(), "dummy", &dummy));
dummy->Increment();
dummy->Unref();
ctx->SetOutput(0, ctx->Input(0));
ctx->SetOutput(1, ctx->Input(0));
}
};
class DummyReadResourceCC {
public:
DummyReadResourceCC(const Scope& scope, const Input& value) {
if (!scope.ok()) return;
auto _value = ops::AsNodeOut(scope, value);
if (!scope.ok()) return;
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return;
this->output1_ = Output(ret, 0);
this->output2_ = Output(ret, 1);
}
Output output1_;
Output output2_;
};
// Declares the DummyReadResource op: one int32 input and two int32 outputs,
// with unknown output shapes. The Doc() text below is runtime data and must
// stay as-is.
REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output1: int32")
.Output("output2: int32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
A dummy Op.
input: dummy input.
output1: dummy output.
output2: dummy output.
)doc");
// Binds DummyReadResourceOp as the XLA kernel, with no device restriction.
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
// DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
// on the same Op name below.
class DummyDuplicateOp : public XlaOpKernel {
public:
explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, ctx->Input(0));
}
};
// Declares DummyDuplicateOp: one int32 input and one int32 output. The Doc()
// text below is runtime data and must stay as-is.
REGISTER_OP("DummyDuplicateOp")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
A dummy Op.
input: dummy input.
output: dummy output.
)doc");
// The same XLA kernel is registered twice under one op name, once per JIT
// device: first for the CPU XLA JIT device, then for the GPU XLA JIT device.
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
DummyDuplicateOp);
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
DummyDuplicateOp);
// Tests compilation and execution of an empty graph: no arguments and no
// return values.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult compiled;
  {
    std::unique_ptr<Graph> empty_graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                       std::move(empty_graph),
                                       /*args=*/{}, &compiled));
  }
  // Running the compiled computation without any parameters must succeed.
  auto execution = client_->Execute(*compiled.computation, {});
  TF_ASSERT_OK(execution.status());
}
// Tests compilation and execution of a graph that adds two int32 vectors.
TEST_F(XlaCompilerTest, Simple) {
  // Graph: D = _Retval(C) where C = Add(A, B) and A, B are arguments 0 and 1.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto arg_a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto arg_b = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
  auto sum = ops::Add(root.WithOpName("C"), arg_a, arg_b);
  auto retval = ops::_Retval(root.WithOpName("D"), sum, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Both arguments are int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Run it: {7, 42} + {-3, 101} should give {4, 143}.
  const xla::Literal lhs = xla::LiteralUtil::CreateR1<int32>({7, 42});
  const xla::Literal rhs = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> lhs_data =
      client_->TransferToServer(lhs).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> rhs_data =
      client_->TransferToServer(rhs).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> output =
      client_->Execute(*result.computation, {lhs_data.get(), rhs_data.get()})
          .ConsumeValueOrDie();
  const xla::Literal output_literal =
      client_->Transfer(*output).ConsumeValueOrDie();

  // The single return value comes back wrapped in a one-element tuple.
  const xla::Literal expected_sum =
      xla::LiteralUtil::CreateR1<int32>({4, 143});
  const xla::Literal expected = xla::LiteralUtil::MakeTuple({&expected_sum});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, output_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
// amongst the graph nodes in construction order, and always_return_tuple is
// false. Regression test for bug where the wrong value was returned.
TEST_F(XlaCompilerTest, OutOfOrderGraph) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
// The _Retval node is not last in construction order.
// D returns argument A directly; C is built afterwards and is never returned.
// Do not reorder these two statements: the ordering is what is under test.
auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
auto c = ops::Add(scope.WithOpName("C"), a, b);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompileOptions compile_options;
// With always_return_tuple disabled, the single return value is not wrapped
// in a tuple, so it can be compared against param0_literal directly below.
compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
// Tests that the generated computation works.
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
// The computation must return parameter 0 unchanged, not the sum computed
// by C.
EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler can correctly propagate the layout assigned by
// shape_representation_fn_ to resource returns that have not been written to.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
  // The graph only returns the resource argument V; it never writes to it.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto resource_arg = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 0);
  auto retval = ops::_Retval(root.WithOpName("D"), resource_arg, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // A single initialized int32[2,3] resource variable.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& variable = args[0];
  variable.kind = XlaCompiler::Argument::kResource;
  variable.resource_kind = XlaResource::kVariable;
  variable.initialized = true;
  variable.type = DT_INT32;
  variable.shape = TensorShape({2, 3});

  // Force the {0, 1} minor-to-major layout on every shape the compiler asks
  // about.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType dt,
         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };

  XlaCompiler compiler(options);
  XlaCompiler::CompileOptions compile_options;
  // Return the resource even though the graph never updates it.
  compile_options.return_updated_values_for_all_resources = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  // Check that the returned resource carries the transposed layout.
  const xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed}));
}
// Tests that the compiler correctly propagates the fast-memory attribute of an
// input resource variable to shape_representation_fn.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
  // The graph only returns the resource argument V.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto resource_arg = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 0);
  auto retval = ops::_Retval(root.WithOpName("D"), resource_arg, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // One initialized int32[2,3] resource variable that requests fast memory.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& variable = args[0];
  variable.kind = XlaCompiler::Argument::kResource;
  variable.resource_kind = XlaResource::kVariable;
  variable.initialized = true;
  variable.type = DT_INT32;
  variable.shape = TensorShape({2, 3});
  variable.fast_mem = true;

  // Counts the calls that ask for a fast-memory representation.
  int fast_mem_requests = 0;
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [&fast_mem_requests](const TensorShape& shape, DataType dt,
                           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    if (use_fast_memory) ++fast_mem_requests;
    return xla_shape;
  };

  XlaCompiler compiler(options);
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  // Expect two requests: one for the argument, one for the return value.
  EXPECT_EQ(fast_mem_requests, 2);
}
// Tests that the compiler can correctly propagate the layout assigned by
// shape_representation_fn_ to return types.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
  // Graph: V += A, then return read(V) + 1.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Route the resource through an Identity op to make sure identity ops
  // propagate resources correctly.
  auto identity = ops::Identity(root.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(root, identity, a);
  // Read only after the write has executed.
  auto read = ops::ReadVariableOp(
      root.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto one = ops::Const<int32>(root, 1);
  auto read_plus_one = ops::Add(root, read, one);
  auto d = ops::_Retval(root.WithOpName("D"), read_plus_one, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2,3] parameter. Argument 1: initialized int32[2,3]
  // resource variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& param = args[0];
  param.kind = XlaCompiler::Argument::kParameter;
  param.type = DT_INT32;
  param.shape = TensorShape({2, 3});
  XlaCompiler::Argument& variable = args[1];
  variable.kind = XlaCompiler::Argument::kResource;
  variable.resource_kind = XlaResource::kVariable;
  variable.initialized = true;
  variable.type = DT_INT32;
  variable.shape = TensorShape({2, 3});

  // Force the {0, 1} minor-to-major layout on every shape.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType dt,
         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };

  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Both output tuple elements (the return value and, presumably, the written
  // variable) must carry the transposed layout. This must hold both in the
  // compilation result and in the computation's program shape.
  const xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
// The layout of a resource variable should not change after a transpose.
TEST_F(XlaCompilerTest, TransposeVariables) {
  // Graph: V += A, then return reshape(transpose(read(V)), [2, 3]).
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Route the resource through an Identity op to make sure identity ops
  // propagate resources correctly.
  auto identity = ops::Identity(root.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(root, identity, a);
  auto read = ops::ReadVariableOp(
      root.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto transposed_read = ops::Transpose(root, read, {1, 0});
  auto reshape = ops::Reshape(root, transposed_read, {2, 3});
  auto d = ops::_Retval(root.WithOpName("D"), reshape, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2,3] parameter. Argument 1: initialized int32[2,3]
  // resource variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& param = args[0];
  param.kind = XlaCompiler::Argument::kParameter;
  param.type = DT_INT32;
  param.shape = TensorShape({2, 3});
  XlaCompiler::Argument& variable = args[1];
  variable.kind = XlaCompiler::Argument::kResource;
  variable.resource_kind = XlaResource::kVariable;
  variable.initialized = true;
  variable.type = DT_INT32;
  variable.shape = TensorShape({2, 3});

  // Default options install no shape_representation_fn.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
                                     std::move(graph), args, &result));

  // Despite the Transpose in the graph, both outputs keep the {1, 0} layout.
  const xla::Shape untransposed =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({untransposed, untransposed}));
}
// An unranked FakeParam compiles to a zero-element int32 tensor of shape [0].
TEST_F(XlaCompilerTest, UnrankedFakeParam) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  // A default-constructed PartialTensorShape is used as the unranked shape.
  PartialTensorShape shape;
  auto a = ops::FakeParam(scope, DT_INT32, shape);
  auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
                                     std::move(graph), {}, &result));

  // Check that the unranked parameter is returned as an s32[0] value (inside
  // the usual one-element result tuple).
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape(
                {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
}
// Tests that the compiler doesn't reorder the parameters.
TEST_F(XlaCompilerTest, MixedOrderArguments) {
  for (const bool swap_order : {false, true}) {
    // Unswapped: the int32 parameter is argument 0 and the resource is
    // argument 1. Swapped: the other way around.
    const int var_index = swap_order ? 0 : 1;
    const int param_index = swap_order ? 1 : 0;

    Scope root = Scope::NewRootScope().ExitOnError();
    auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, var_index);
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, param_index);
    // Route the resource through an Identity op to make sure identity ops
    // propagate resources correctly.
    auto identity = ops::Identity(root.WithOpName("VIdentity"), var);
    auto write = ops::AssignAddVariableOp(root, identity, a);
    auto read = ops::ReadVariableOp(
        root.WithControlDependencies(std::vector<Operation>{write}), var,
        DT_INT32);
    auto one = ops::Const<int32>(root, 1);
    auto read_plus_one = ops::Add(root, read, one);
    auto d = ops::_Retval(root.WithOpName("D"), read_plus_one, 0);
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(root.ToGraph(graph.get()));

    // Describe each argument at the same index its _Arg node uses.
    std::vector<XlaCompiler::Argument> args(2);
    XlaCompiler::Argument& param_arg = args[param_index];
    param_arg.kind = XlaCompiler::Argument::kParameter;
    param_arg.type = DT_INT32;
    param_arg.shape = TensorShape({2});
    XlaCompiler::Argument& var_arg = args[var_index];
    var_arg.kind = XlaCompiler::Argument::kResource;
    var_arg.resource_kind = XlaResource::kVariable;
    var_arg.initialized = true;
    var_arg.type = DT_INT32;
    var_arg.shape = TensorShape({2});

    XlaCompiler compiler(DefaultOptions());
    XlaCompiler::CompileOptions compile_options;
    compile_options.always_return_tuple = false;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                       args, &result));
    // Whichever argument kind comes first, the compiler must keep the given
    // parameter order.
    EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
  }
}
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Builds a graph that reshapes a tensor, where the target shape comes from a
  // runtime argument and so is not statically known.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Reshape(scope.WithOpName("C"), a, b);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph; this must fail.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  Status status =
      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
                            std::move(graph), args, &result);
  // Fatal: if compilation unexpectedly succeeds, the error-message checks
  // below would only add noise against an empty message.
  ASSERT_FALSE(status.ok());

  // The error must name the offending node and explain why it failed.
  const string& message = status.error_message();
  EXPECT_TRUE(absl::StrContains(message, "depends on a parameter")) << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node C}}")) << message;
  EXPECT_TRUE(absl::StrContains(message, "must be a compile-time constant"))
      << message;
}
// Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest, ConstantOutputs) {
  // Builds a graph with one compile-time constant output and one
  // data-dependent output, i.e.,
  //   func(a) { b = 7; c = -a; return b, c; }
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto b = ops::Const<int32>(root.WithOpName("B"), 7);
  auto c = ops::Neg(root.WithOpName("C"), a);
  auto d = ops::_Retval(root.WithOpName("D"), b, 0);
  auto e = ops::_Retval(root.WithOpName("E"), c, 1);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // A single int32[2] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& param = args[0];
  param.kind = XlaCompiler::Argument::kParameter;
  param.type = DT_INT32;
  param.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());

  // The compiler is given a copy of `graph`, not `graph` itself.
  std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
  CopyGraph(*graph, graph_copy.get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "constants",
                                     std::move(graph_copy), args, &result));

  // Neither output is reported as a compile-time constant; both are produced
  // by the computation.
  ASSERT_EQ(2, result.outputs.size());
  EXPECT_FALSE(result.outputs[0].is_constant);
  EXPECT_FALSE(result.outputs[1].is_constant);

  // Run with a = {7, 42} and expect the tuple (7, {-7, -42}).
  const xla::Literal input = xla::LiteralUtil::CreateR1<int32>({7, 42});
  std::unique_ptr<xla::GlobalData> input_data =
      client_->TransferToServer(input).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> output =
      client_->Execute(*result.computation, {input_data.get()})
          .ConsumeValueOrDie();
  const xla::Literal output_literal =
      client_->Transfer(*output).ConsumeValueOrDie();
  const xla::Literal expected_const = xla::LiteralUtil::CreateR0<int32>(7);
  const xla::Literal expected_neg =
      xla::LiteralUtil::CreateR1<int32>({-7, -42});
  const xla::Literal expected =
      xla::LiteralUtil::MakeTuple({&expected_const, &expected_neg});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, output_literal));
}
TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
  // Define a function with one compile-time constant output and one
  // data-dependent output:
  //   @function.Defun(noinline=True)
  //   foo(a) { b = 7; return b, a; }
  const Tensor seven = test::AsScalar<int>(7);
  FunctionDef foo_fdef = FunctionDefHelper::Create(
      "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
      {
          {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
      },
      {{"a", "a_0"}, {"const", "Const:output:0"}});
  (*foo_fdef.mutable_attr())["_noinline"].set_b(true);
  FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = foo_fdef;

  // Builds an int32 _Retval NodeDef called `name` that returns `input` as
  // return value number `index`.
  const auto make_retval = [](const string& name, const string& input,
                              int index) {
    NodeDef retval;
    retval.set_name(name);
    retval.set_op(FunctionLibraryDefinition::kRetOp);
    *retval.add_input() = input;
    (*retval.mutable_attr())["T"].set_type(DT_INT32);
    (*retval.mutable_attr())["index"].set_i(index);
    return retval;
  };

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
    auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);

    // Call foo on the argument.
    NodeDef call;
    call.set_name("foo");
    call.set_op("foo");
    *call.add_input() = "input_arg";
    Status status;
    scope.graph()->AddNode(call, &status);
    TF_ASSERT_OK(status);

    // Return both of foo's outputs.
    scope.graph()->AddNode(make_retval("retval_0", "foo", 0), &status);
    TF_ASSERT_OK(status);
    scope.graph()->AddNode(make_retval("retval_1", "foo:1", 1), &status);
    TF_ASSERT_OK(status);

    TF_ASSERT_OK(scope.ToGraph(graph.get()));
  }

  // A single int32[1] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& param = args[0];
  param.kind = XlaCompiler::Argument::kParameter;
  param.type = DT_INT32;
  param.shape = TensorShape({1});

  // Compile against a function library that contains foo.
  XlaCompiler::Options options = DefaultOptions();
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  options.flib_def = &flib_def;
  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "constants",
                                     std::move(graph), args, &result));

  // Output 1 forwards the runtime argument, so it must not be reported as a
  // compile-time constant.
  ASSERT_EQ(2, result.outputs.size());
  EXPECT_FALSE(result.outputs[1].is_constant);
}
// Tests that the populate_resource_manager callback runs, and that the
// resource it registers is visible to XLA kernels during compilation. The
// DummyReadResource kernel increments the resource's counter exactly once.
TEST_F(XlaCompilerTest, ResourceManager) {
  // Builds a graph that calls the dummy resource Op.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
  auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the argument.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  DummyResourceForTest* resource = new DummyResourceForTest();
  // Drops the test's own reference on every exit path, including the early
  // return of a failing TF_ASSERT_OK below (the explicit Unref() at the end
  // of the test used to be skipped then). It is declared before `compiler`,
  // so it is destroyed after the compiler.
  core::ScopedUnref resource_unref(resource);

  // Compiles the graph. The callback takes an extra reference before handing
  // the resource to the resource manager via Create().
  auto options = DefaultOptions();
  std::function<Status(ResourceMgr*)> populate_function =
      [resource](ResourceMgr* rm) {
        resource->Ref();
        return rm->Create(rm->default_container(), "dummy", resource);
      };
  options.populate_resource_manager = &populate_function;
  XlaCompiler compiler(options);

  EXPECT_EQ(0, resource->Get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                     std::move(graph), args, &result));
  EXPECT_EQ(1, resource->Get());
}
// Tests that compiling the same graph more than once yields identical HLO,
// ignoring instruction names and ids, which the XlaBuilder uniquifies.
TEST_F(XlaCompilerTest, DeterministicCompilation) {
  // Builds a graph that contains a node with two output edges. The compiler
  // should always traverse them in the same order.
  const int64 test_count = 2;
  std::vector<XlaCompiler::CompilationResult> results(test_count);
  for (int64 i = 0; i < test_count; ++i) {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Neg(scope.WithOpName("B"), a);
    auto c = ops::Neg(scope.WithOpName("C"), a);
    auto d = ops::Add(scope.WithOpName("D"), b, c);
    auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(scope.ToGraph(graph.get()));

    // Builds a description of the argument.
    std::vector<XlaCompiler::Argument> args(1);
    args[0].kind = XlaCompiler::Argument::kParameter;
    args[0].type = DT_INT32;
    args[0].shape = TensorShape({2});

    // Compiles the graph with a fresh compiler on every iteration.
    XlaCompiler compiler(DefaultOptions());
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                       std::move(graph), args, &results[i]));
  }

  // Compare each compilation against the previous one.
  for (int64 i = 1; i < test_count; ++i) {
    const auto& m1 = results[i - 1].computation->proto();
    const auto& m2 = results[i].computation->proto();
    ASSERT_EQ(m1.computations_size(), m2.computations_size());
    // Check that every HLO computation is the same.
    for (int k = 0; k < m1.computations_size(); k++) {
      const auto& c1 = m1.computations(k);
      const auto& c2 = m2.computations(k);
      ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
      for (int j = 0; j < c1.instructions_size(); j++) {
        // Copies, because the fields that may legitimately differ are
        // cleared before comparing.
        auto instr1 = c1.instructions(j);
        auto instr2 = c2.instructions(j);
        // Instruction names are uniquified by the XlaBuilder, and the unique
        // ids (and therefore operand ids) may differ. Every other field
        // should be identical.
        instr1.clear_name();
        instr1.clear_id();
        instr1.clear_operand_ids();
        instr2.clear_name();
        instr2.clear_id();
        instr2.clear_operand_ids();
        // Dump the instructions only when they differ, instead of logging
        // every instruction on every run.
        EXPECT_EQ(instr1.SerializePartialAsString(),
                  instr2.SerializePartialAsString())
            << "instr1 = " << instr1.DebugString()
            << "\ninstr2 = " << instr2.DebugString();
      }
    }
  }
}
// Tests a computation that receives a TensorArray resource as input and
// updates it.
TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
auto flow = ops::Const<float>(scope, {});
// "grad1" is not in the argument's tensor_array_gradients below, so it is
// first created inside the computation. "grad2" is passed in with the argument.
auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
auto index = ops::Const<int32>(scope, 1);
// Writes the value 1 (`index` doubles as the value) at index 1 of grad1.
auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
grad2.flow_out);
// Reads element 1 of the base TensorArray; this is the returned value.
auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments: an initialized int32 TensorArray of
// scalar elements, max_array_size 2, with one existing gradient ("grad2").
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kResource;
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
args[0].max_array_size = 2;
args[0].tensor_array_gradients = {"grad2"};
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
// The TensorArray is reported as one resource update for argument 0, and it
// records both gradients as accessed.
ASSERT_EQ(1, result.resource_updates.size());
const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
EXPECT_EQ(0, update.input_index);
EXPECT_EQ(DT_INT32, update.type);
EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
// The TensorArray parameter is fed as a tuple: the base array, then its
// existing grad2 gradient.
xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
// Expected result: (read value, updated TensorArray).
//   - The read value is element 1 of the base array: 42.
//   - The updated TensorArray tuple holds the unchanged base, then the new
//     grad1 (all zeros except the 1 written at index 1), then the unchanged
//     grad2.
xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal output_resource =
xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests that a TensorArray gradient which the argument already carries
// ("grad1"), and which the graph only requests but never writes, does not
// turn the TensorArray into a resource update of the computation.
TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // Requests the "grad1" gradient, which the argument already has (see
  // tensor_array_gradients below). Its flow output only orders the read,
  // which reads from the base TensorArray `arg`.
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Argument 0: an initialized int32 TensorArray with max size 2 that already
  // carries a "grad1" gradient.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  // Nothing was written and no new gradient was requested, so no resource
  // update is expected.
  EXPECT_EQ(0, result.resource_updates.size());
}
// Tests that requesting a TensorArray gradient ("grad2") that the argument
// does not already carry makes the TensorArray a resource update of the
// computation.
TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // "grad2" is not listed in the argument's tensor_array_gradients below.
  auto grad2 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad2.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Argument 0: an initialized int32 TensorArray with max size 2 that carries
  // only a "grad1" gradient.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  // The newly requested gradient makes the TensorArray an output.
  EXPECT_EQ(1, result.resource_updates.size());
}
// Tests that CompileFunction fails, with an error mentioning that the function
// "is not defined.", when asked to compile a name absent from the library.
TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("Function_NotDefined_");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);
  // Stop here on unexpected success; the message check below is meaningless
  // for an OK status.
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
      << status.error_message();
}
// Returns the definition of "FillFn": a function parameterized on T (one of
// float, double, int32, int64) that takes a value `x` and an int32 `dims`
// tensor and returns y = Fill(dims, x).
FunctionDef FillFn() {
  const std::vector<string> input_args = {"x: T", "dims: int32"};
  const std::vector<string> output_args = {"y: T"};
  const std::vector<string> type_attrs = {"T: {float, double, int32, int64}"};
  // Fill receives `dims` first and the fill value second.
  const std::vector<FunctionDefHelper::Node> body = {
      {{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}};
  return FunctionDefHelper::Define("FillFn", input_args, output_args,
                                   type_attrs, body);
}
TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
  // Certain operations in a function, "Fill" for example, require the
  // operator's argument to be a compile-time constant instead of a parameter.
  // This testcase tests if XlaCompiler can handle such operators inside
  // function calls whose inputs are constants in the caller.
  XlaCompiler compiler(DefaultOptions());
  FunctionDefLibrary flib;
  *flib.add_function() = FillFn();
  // FillFn goes into the fixture's library (flib_def_, used by NodeDefBuilder
  // below) and into the graph's own function library.
  TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  // Without the library the "fill" node below cannot be resolved, so abort
  // instead of producing cascading failures.
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  // Graph::AddNode does not create edges from the NodeDef's inputs, so the
  // data edges are wired explicitly.
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);
  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // The graph takes no arguments.
  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                     std::move(graph), args, &result));
}
// Tests that when a function is found only in the compiler's local function
// library and cannot be instantiated there (XTimesTwo is compiled without its
// T attr), CompileFunction fails with an error that reports both lookups: the
// global-library miss and the local-library instantiation failure.
TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
  XlaCompiler compiler(DefaultOptions());
  auto local_flib_def = LocalFlibDef(&compiler);
  TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));

  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("XTimesTwo");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);
  ASSERT_FALSE(status.ok());
  // Flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
      << status.error_message();
  // Local flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
      << status.error_message();
}
// Executes `result.computation` with parameter 0 = [7, 42] and parameter 1 =
// [-3, 101] and expects the 2-tuple ([5, 144], [4, 143]), i.e. the sum of the
// two parameters plus one, followed by the sum itself.
void RunAndCheckVariablesComputation(
    xla::Client* client, const XlaCompiler::CompilationResult& result) {
  const xla::Literal first_value = xla::LiteralUtil::CreateR1<int32>({7, 42});
  const xla::Literal second_value =
      xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> first_handle =
      client->TransferToServer(first_value).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> second_handle =
      client->TransferToServer(second_value).ConsumeValueOrDie();

  const std::vector<xla::GlobalData*> arguments = {first_handle.get(),
                                                   second_handle.get()};
  std::unique_ptr<xla::GlobalData> output =
      client->Execute(*result.computation, arguments).ConsumeValueOrDie();
  const xla::Literal observed = client->Transfer(*output).ConsumeValueOrDie();

  // [7, 42] + [-3, 101] = [4, 143]; adding one gives [5, 144].
  const xla::Literal want_first = xla::LiteralUtil::CreateR1<int32>({5, 144});
  const xla::Literal want_second = xla::LiteralUtil::CreateR1<int32>({4, 143});
  const xla::Literal want =
      xla::LiteralUtil::MakeTuple({&want_first, &want_second});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(want, observed));
}
// Tests a simple graph that reads and writes a variable.
// The graph applies V += A through an Identity of the resource handle and
// returns D = read(V) + 1, where a control dependency orders the read after
// the write. RunAndCheckVariablesComputation executes and checks the result.
TEST_F(XlaCompilerTest, Variables) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
// Adds an identity op around the resource to make sure identity ops propagate
// resources correctly.
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
auto write = ops::AssignAddVariableOp(scope, identity, a);
// The control dependency makes the read run after the write above.
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
// Argument 0 (A): an int32[2] parameter.
// Argument 1 (V): an initialized int32[2] resource variable.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
RunAndCheckVariablesComputation(client_, result);
}
// Tests that a layout chosen by shape_representation_fn shows up in the
// compiled output shape when the single retval is returned without a tuple
// wrapper (always_return_tuple = false).
TEST_F(XlaCompilerTest, ResultLayoutSingle) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto b = ops::_Retval(scope.WithOpName("RET"), a, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 3});
auto options = DefaultOptions();
// Sets the representation function to return a non-default layout.
// The dimensions are unchanged; only the layout is forced to {0, 1}.
// use_fast_memory is ignored.
options.shape_representation_fn =
[](const TensorShape& shape, DataType type,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
return xla_shape;
};
// Compiles the graph.
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
auto compile_options = XlaCompiler::CompileOptions();
compile_options.always_return_tuple = false;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
args, &result));
// The output is the bare S32[2,3] array carrying the custom layout.
EXPECT_TRUE(xla::ShapeUtil::Equal(
result.xla_output_shape,
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
// The computation's program shape must agree with xla_output_shape.
EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
result.xla_output_shape);
}
// Tests that a layout chosen by shape_representation_fn is applied to every
// element of the output tuple when the graph has two retvals.
TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
// The same argument is returned twice, as retvals 0 and 1.
auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0);
auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 3});
auto options = DefaultOptions();
// Sets the representation function to return a non-default layout.
// The dimensions are unchanged; only the layout is forced to {0, 1}.
// use_fast_memory is ignored.
options.shape_representation_fn =
[](const TensorShape& shape, DataType type,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
return xla_shape;
};
// Compiles the graph.
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
std::move(graph), args, &result));
// Expected output: a 2-tuple whose elements are both S32[2,3] with layout
// {0, 1}.
xla::Shape result_shape =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
EXPECT_TRUE(xla::ShapeUtil::Equal(
result.xla_output_shape,
xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
// The computation's program shape must agree with xla_output_shape.
EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
result.xla_output_shape);
}
// Tests a graph that only returns a resource variable's handle, without
// reading or writing the variable. The compiled computation is expected to
// return an empty tuple: the returned handle is not itself an XLA output.
TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
  auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Argument 0 (V): an initialized int32[2] resource variable.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kVariable;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Runs the computation with the variable's value as its only parameter.
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, {param0_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Like Variables, but the graph also returns the resource handle V as retval
// 0 ("R"), with D as retval 1. RunAndCheckVariablesComputation still expects
// the same 2-tuple as in Variables, so the returned handle adds no XLA output.
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
// Adds an identity op around the resource to make sure identity ops propagate
// resources correctly.
auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
auto write = ops::AssignAddVariableOp(scope, identity, a);
// The control dependency makes the read run after the write above.
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
// Argument 0 (A): an int32[2] parameter.
// Argument 1 (V): an initialized int32[2] resource variable.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
RunAndCheckVariablesComputation(client_, result);
}
// Builds a graph with an int32 argument "A" (index 0) and a resource argument
// "V" (index 1). The graph performs V += A and returns D = read(V) + 1 as
// retval 0, with the read ordered after the update by a control dependency.
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto addend = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto variable = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  auto assign_add = ops::AssignAddVariableOp(root, variable, addend);

  const Scope after_assign =
      root.WithControlDependencies(std::vector<Operation>{assign_add});
  auto current = ops::ReadVariableOp(after_assign, variable, DT_INT32);
  auto one = ops::Const<int32>(root, 1);
  auto incremented = ops::Add(root, current, one);
  auto retval = ops::_Retval(root.WithOpName("D"), incremented, 0);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_RETURN_IF_ERROR(root.ToGraph(graph.get()));
  // Explicit move: the return type is StatusOr, not unique_ptr itself.
  return std::move(graph);
}
// Tests a simple graph that reads and writes a variable, with a
// shape_representation_fn passed to the compiler that flattens all
// variable tensors to vectors.
// With is_entry_computation = false only the resource variable is reshaped:
// parameter 0 and retval 0 keep their [2, 2] shape, while the variable is
// passed in and returned as a 4-element vector.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments.
// Argument 0 (A): an int32[2,2] parameter.
// Argument 1 (V): an initialized int32[2,2] resource variable.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 2});
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
// Maps every tensor to a rank-1 XLA shape with the same element count.
options.shape_representation_fn =
[](const TensorShape& shape, DataType type,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
};
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false; // Only reshape variables.
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(
xla::ShapeUtil::Compatible(program_shape->parameters(0),
xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
// The variable is fed in its flattened form.
xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
// expected1 is the updated variable (A + V, flattened); expected0 is that sum
// plus one, in the original [2, 2] shape.
xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Same graph and shape_representation_fn as
// VariableRepresentationShapeFunction, but with is_entry_computation = true,
// so the parameter and the retval are flattened too: every parameter and
// every result element is S32[4].
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments.
// Argument 0 (A): an int32[2,2] parameter.
// Argument 1 (V): an initialized int32[2,2] resource variable.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 2});
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
// Maps every tensor to a rank-1 XLA shape with the same element count.
options.shape_representation_fn =
[](const TensorShape& shape, DataType type,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
};
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true; // Reshape args and retvals.
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {4}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
// Both parameters are fed in flattened form.
xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
// expected1 is the updated variable (A + V); expected0 is that sum plus one.
xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.
// FillFn is extended with an unregistered "InvalidOp" node and a Switch node.
// Compiling a graph that calls it must fail, and the error must mention both
// InvalidOp and the calling node "fill_fn".
TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
XlaCompiler compiler(DefaultOptions());
FunctionDefLibrary flib;
FunctionDef fn = FillFn();
NodeDef* node = fn.add_node_def();
node->set_name("Invalid");
node->set_op("InvalidOp"); /* unsupported op */
node = fn.add_node_def();
node->set_name("Switch");
node->set_op("Switch"); /* control flow node */
*flib.add_function() = fn;
TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope scope = Scope::NewRootScope().ExitOnError();
auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
NodeDef def;
TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
.Input(value.name(), 0, DT_INT32)
.Input(shape.name(), 1, DT_INT32)
.Finalize(&def));
Status status;
Node* fill = scope.graph()->AddNode(def, &status);
TF_ASSERT_OK(status);
TF_ASSERT_OK(scope.DoShapeInference(fill));
// The data edges into the call node are added explicitly.
scope.graph()->AddEdge(value.node(), 0, fill, 0);
scope.graph()->AddEdge(shape.node(), 0, fill, 1);
auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// The graph takes no arguments.
std::vector<XlaCompiler::Argument> args;
XlaCompiler::CompilationResult result;
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
<< status.error_message();
}
// Tests that a node whose attribute violates the op's type constraint (a
// Shape op with out_type = DT_BOOL) fails compilation. The error must say the
// value is not allowed and must name the offending node.
TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
  NodeDef shape_def;
  shape_def.set_name("Shape");
  shape_def.set_op("Shape");
  auto& attrs = *shape_def.mutable_attr();
  attrs["T"].set_type(DT_INT32);
  attrs["out_type"].set_type(DT_BOOL); /* invalid type */

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Status add_status;
  Node* shape_node = graph->AddNode(shape_def, &add_status);
  TF_ASSERT_OK(add_status);
  graph->AddControlEdge(graph->source_node(), shape_node);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> no_args;
  XlaCompiler::CompilationResult result;
  const Status compile_status =
      compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
                            std::move(graph), no_args, &result);
  ASSERT_FALSE(compile_status.ok());

  const string& message = compile_status.error_message();
  EXPECT_TRUE(
      absl::StrContains(message, "is not in the list of allowed values"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node Shape}}")) << message;
}
// Tests that a graph consisting of a single NoOp, with no control edge linking
// it to the source or sink nodes, compiles successfully.
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
  NodeDef no_op_def;
  no_op_def.set_name("NoOp");
  no_op_def.set_op("NoOp");

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Status add_status;
  graph->AddNode(no_op_def, &add_status);
  TF_ASSERT_OK(add_status);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> no_args;

  // The compiler takes ownership of the graph it compiles, so hand it a copy.
  std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
  CopyGraph(*graph, graph_copy.get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
                                     std::move(graph_copy), no_args, &result));
}
// A test-only XLA kernel. Its Compile() creates a fresh XLA token and records
// it for this node via XlaCompiler::SetNodeToken; it produces no outputs.
// TokenInputAndOutput uses it to exercise token plumbing.
class DummySideEffectingOp : public XlaOpKernel {
public:
explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
// Fails the kernel (via OP_REQUIRES_OK) if the token cannot be recorded.
OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
name(), xla::CreateToken(ctx->builder())));
}
};
// Registers the op with no inputs, outputs, or attrs, and binds the kernel
// above as its XLA implementation.
REGISTER_OP("DummySideEffectingOp");
REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
// Tests how token inputs and outputs appear in the compiled shapes for a graph
// containing one side-effecting op. Two modes are checked: entry computations
// (no token parameter or result) and non-entry computations with
// add_token_input_output (token as the last parameter and result element).
TEST_F(XlaCompilerTest, TokenInputAndOutput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
NodeDef side_effecting_op;
side_effecting_op.set_name("DummySideEffectingOp");
side_effecting_op.set_op("DummySideEffectingOp");
// Lists the token argument node (kXlaTokenArgNodeName) as this op's token
// input.
AddNodeAttr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(),
&side_effecting_op);
Status status;
graph->AddNode(side_effecting_op, &status);
TF_ASSERT_OK(status);
EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
// Argument 0: an initialized int32[2,2] resource variable.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kResource;
args[0].resource_kind = XlaResource::kVariable;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 2});
{
// The case for entry computation: we don't add token input/output. Instead,
// we use CreateToken HLO to create the entry token.
XlaCompiler::CompileOptions options;
options.is_entry_computation = true;
options.add_token_input_output = false;
options.return_updated_values_for_all_resources = true;
XlaCompiler compiler(DefaultOptions());
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
args, &result));
// One input and a 1-element output tuple: no token in either.
EXPECT_EQ(result.xla_input_shapes.size(), 1);
EXPECT_TRUE(result.xla_output_shape.IsTuple());
EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
}
{
// The case for non-entry computation (e.g. while loop body). We add token
// input/output.
XlaCompiler::CompileOptions options;
options.is_entry_computation = false;
options.add_token_input_output = true;
options.return_updated_values_for_all_resources = true;
XlaCompiler compiler(DefaultOptions());
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
args, &result));
// The token is appended as input 1 and as output tuple element 1.
EXPECT_EQ(result.xla_input_shapes.size(), 2);
EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
EXPECT_TRUE(result.xla_output_shape.IsTuple());
EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2);
EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1)
.IsToken());
}
}
// Tests that outputs derived from a DT_VARIANT kTensorList argument are
// reported as tensor lists. The two outputs are an AddN of the list with
// itself and the result of a While loop that passes the list through.
TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
// Build cond fn for While.
// The condition ignores its argument and always returns true.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto result = ops::Const<bool>(scope, {true}, {});
ops::_Retval(scope.WithOpName("ret"), result, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
// Build body fn for While.
// The body returns its tensor-list argument unchanged.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
ops::_Retval(scope.WithOpName("ret"), arg, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
// NOTE(review): element_shape and max_elements are added to the graph but no
// other op consumes them.
auto element_shape = ops::Const<int32>(scope, {1}, {1});
auto max_elements = ops::Const<int32>(scope, {10}, {});
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
std::initializer_list<Output> out = {arg, arg};
auto add_n = ops::AddN(scope, out);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("cond");
body_fn.set_name("body");
auto while_op =
ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn);
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0);
auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
// Argument 0 is a tensor list described by the XLA tuple shape
// (S32[1], S32[]).
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kTensorList;
xla::Shape tensor_list_element_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1},
&tensor_list_element_shape));
xla::Shape index_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape));
std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape};
xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes);
args[0].shape = arg_shape;
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.flib_def = &flib_def;
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
// Both the AddN output and the While output must be tagged as tensor lists.
ASSERT_EQ(result.outputs.size(), 2);
const XlaCompiler::OutputDescription& output0 = result.outputs[0];
ASSERT_TRUE(output0.is_tensor_list);
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
ASSERT_TRUE(output1.is_tensor_list);
}
// Test the compiler supports WhileOp with a loop body where DT_RESOURCE
// variables are both inputs and outputs.
TEST_F(XlaCompilerTest, WhileWithResources) {
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
// Build cond fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
auto less = ops::Less(scope, arg0, ops::Const<int32>(scope, 10));
(void)ops::_Retval(scope.WithOpName("ret"), less, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
// Build body fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32);
auto plus_read1 = ops::Add(scope, arg0, read1);
auto read2 = ops::ReadVariableOp(scope.WithOpName("read2"), arg2, DT_INT32);
auto minus_read2 = ops::Sub(scope, plus_read1, read2);
(void)ops::_Retval(scope.WithOpName("ret0"), minus_read2, 0);
(void)ops::_Retval(scope.WithOpName("ret1"), arg1, 1);
(void)ops::_Retval(scope.WithOpName("ret2"), arg2, 2);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("cond");
body_fn.set_name("body");
auto while_op = ops::While(
scope, std::initializer_list<Input>{arg0, arg1, arg2}, cond_fn, body_fn);
(void)ops::_Retval(scope.WithOpName("ret0"), while_op.output[0], 0);
(void)ops::_Retval(scope.WithOpName("ret1"), while_op.output[1], 1);
(void)ops::_Retval(scope.WithOpName("ret2"), while_op.output[2], 2);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(3);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({});
args[2].kind = XlaCompiler::Argument::kResource;
args[2].resource_kind = XlaResource::kVariable;
args[2].initialized = true;
args[2].type = DT_INT32;
args[2].shape = TensorShape({});
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.flib_def = &flib_def;
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options = XlaCompiler::CompileOptions();
compile_options.return_updated_values_for_all_resources = true;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
std::move(graph), args, &result));
ASSERT_EQ(result.outputs.size(), 3);
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
ASSERT_EQ(output1.input_index, 1);
const XlaCompiler::OutputDescription& output2 = result.outputs[2];
ASSERT_EQ(output2.input_index, 2);
// Tests that the generated computation works.
xla::Literal literal0 = xla::LiteralUtil::CreateR0<int32>(0);
xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
xla::Literal literal2 = xla::LiteralUtil::CreateR0<int32>(1);
std::unique_ptr<xla::GlobalData> data0 =
client_->TransferToServer(literal0).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> data1 =
client_->TransferToServer(literal1).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> data2 =
client_->TransferToServer(literal2).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation,
{data0.get(), data1.get(), data2.get()})
.ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(10);
xla::Literal expected1 = xla::LiteralUtil::CreateR0<int32>(2);
xla::Literal expected2 = xla::LiteralUtil::CreateR0<int32>(1);
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
  // Graph under test: _Arg "A" is forwarded directly to _Retval "B", so the
  // computation simply returns its only argument.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto input = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  (void)ops::_Retval(scope.WithOpName("B"), input, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Attach a tiled sharding (tile assignment [0, 1]) to the _Retval node via
  // its serialized _XlaSharding attribute.
  auto nodes_by_name = graph->BuildNodeNameIndex();
  Node* retval_node = nodes_by_name["B"];
  ASSERT_NE(retval_node, nullptr);
  xla::Array<int64> devices({2});
  devices.FillIota(0);
  xla::HloSharding tile_sharding = xla::HloSharding::Tile(devices);
  retval_node->AddAttr("_XlaSharding",
                       tile_sharding.ToProto().SerializeAsString());

  // A single int32[2] parameter feeds arg 0.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
                                     std::move(graph), args, &result));

  // Locate the root instruction of the (single) entry computation. The module
  // proto is held by reference, so pointing into it is safe here.
  const auto& module_proto = result.computation->proto();
  ASSERT_EQ(module_proto.computations_size(), 1);
  const auto& computation_proto = module_proto.computations(0);
  const xla::HloInstructionProto* root = nullptr;
  for (const auto& instruction : computation_proto.instructions()) {
    if (instruction.id() != computation_proto.root_id()) continue;
    root = &instruction;
    break;
  }
  ASSERT_NE(root, nullptr);

  // The root TUPLE must carry the retval sharding wrapped in a tuple sharding.
  xla::Shape returned_shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {2})});
  xla::HloSharding expected_sharding = xla::HloSharding::Tuple(
      returned_shape, std::vector<xla::HloSharding>{tile_sharding});
  EXPECT_EQ(root->sharding().SerializeAsString(),
            expected_sharding.ToProto().SerializeAsString());
}
TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
  // When a dynamically-shaped input feeds a Shape op, the Shape op must not be
  // constant-folded to the input's static shape; it should return the dynamic
  // size instead:
  //
  // [2, b]   // b's static size is 3 and dynamic size is 2
  //   |
  // Shape    // should return [2, 2]
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  // Arg 1 carries the runtime size of a's dynamic dimension. No op in the
  // graph consumes it; it is wired up only through
  // args[0].dynamic_dim_to_arg_num_map below.
  (void)ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto shape = ops::Shape(scope.WithOpName("shape"), a);
  (void)ops::_Retval(scope.WithOpName("retval"), shape, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 3});
  // Marks dimension 1 of arg 0 (the second dimension, static size 3) as
  // dynamic, with its runtime size supplied by arg 1.
  args[0].dynamic_dim_to_arg_num_map.insert({1, 1});
  // Arg 1 holds the dynamic size.
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  auto options = XlaCompiler::CompileOptions();
  TF_ASSERT_OK(
      compiler.CompileGraph(options, "test", std::move(graph), args, &result));

  // Prepares arguments: a 2x3 input plus a dynamic size of 2 for dimension 1.
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
  xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
  std::unique_ptr<xla::GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  // Runs the computation and fetches the result.
  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, {data0.get(), data1.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // The Shape op reports the dynamic size <2, 2>, not the static size <2, 3>.
  xla::Literal expected = xla::LiteralUtil::CreateR1<int32>({2, 2});
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, AliasResourceUpdates) {
  // Graph: V += A, then read V back (ordered after the write) and return it.
  // With alias_resource_update enabled, the compiled module is expected to
  // record exactly one input/output alias for the updated variable.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::Const<int32>(scope.WithOpName("A"), {1, 2});
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  auto write = ops::AssignAddVariableOp(scope, var, a);
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  (void)ops::_Retval(scope.WithOpName("D"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments. Arg 0 is a compile-time constant,
  // so the resource variable (arg 1) is expected to be XLA parameter 0.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kConstant;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  // NOTE(review): `{1, 1}` appears to select the Tensor(DataType, TensorShape)
  // constructor, i.e. a 1x1 tensor rather than the values {1, 1}, which does
  // not match args[0].shape. The graph has no _Arg with index 0, so this
  // looks unused — confirm before relying on it.
  args[0].constant_value = Tensor(DT_INT32, {1, 1});
  args[0].initialized = true;
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompileOptions compile_options;
  compile_options.alias_resource_update = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  const xla::HloInputOutputAliasProto& alias =
      result.computation->proto().input_output_alias();
  // Fatal assertion: if no alias entry exists, entries(0) below would index
  // an empty repeated field out of range.
  ASSERT_EQ(alias.entries_size(), 1);
  EXPECT_EQ(alias.entries(0).parameter_number(), 0);
}
// Tests that registering identical device-to-host metadata twice under the
// same key via SetDeviceToHostMetadata is accepted (not an error).
TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
  XlaCompiler compiler(DefaultOptions());
  const string key = "comm_key";
  const std::vector<DataType> types = {DT_INT32};
  const std::vector<TensorShape> shapes = {TensorShape({2})};
  // First registration establishes the metadata for `key`.
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
  // Re-registering the exact same types and shapes must also succeed.
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
}
// Tests that passing in a mismatched duplicate input to
// SetDeviceToHostMetadata is an error: re-registering an existing key with
// different types/shapes must fail with INVALID_ARGUMENT.
TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
  XlaCompiler compiler(DefaultOptions());
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};
  std::vector<DataType> types2{DT_FLOAT};
  std::vector<TensorShape> shapes2{TensorShape({1})};
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
  // Same key, different dtype and shape: must be rejected.
  Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
}
// Tests that registering identical host-to-device metadata twice under the
// same key via SetHostToDeviceMetadata is accepted (not an error).
TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
  XlaCompiler compiler(DefaultOptions());
  const string key = "comm_key";
  const std::vector<DataType> types = {DT_INT32};
  const std::vector<TensorShape> shapes = {TensorShape({2})};
  // First registration establishes the metadata for `key`.
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
  // Re-registering the exact same types and shapes must also succeed.
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
}
// Tests that passing in a mismatched duplicate input to
// SetHostToDeviceMetadata is an error: re-registering an existing key with
// different types/shapes must fail with INVALID_ARGUMENT.
TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
  XlaCompiler compiler(DefaultOptions());
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};
  std::vector<DataType> types2{DT_FLOAT};
  std::vector<TensorShape> shapes2{TensorShape({1})};
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
  // Same key, different dtype and shape: must be rejected.
  Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
}
} // namespace
} // namespace tensorflow