Global de-std::unique_ptr cleanup for xla::Literal.

PiperOrigin-RevId: 212313258
This commit is contained in:
A. Unique TensorFlower 2018-09-10 12:33:49 -07:00 committed by TensorFlower Gardener
parent 656b3e9c84
commit dd6d7c5c58
147 changed files with 3785 additions and 4183 deletions
tensorflow/compiler
tf2xla
xla
client
literal.ccliteral.hliteral_test.ccliteral_util.ccliteral_util.hpacked_literal_reader.ccpacked_literal_reader.h
python
reference_util.ccreference_util_test.cc
rpc
service
tests

View File

@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}

View File

@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
&b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
&b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
&b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
&b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =

View File

@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U32:
literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
case xla::U64:
literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
literal = xla::LiteralUtil::CreateR0<uint64>(value);
break;
case xla::S8:
literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S32:
literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
case xla::S64:
literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
literal = xla::LiteralUtil::CreateR0<int64>(value);
break;
case xla::F32:
literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
literal = xla::LiteralUtil::CreateR0<float>(value);
break;
case xla::F64:
literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
literal = xla::LiteralUtil::CreateR0<double>(value);
break;
case xla::C64:
literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal = std::move(
*xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
literal =
xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
break;
case xla::F16:
literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
static_cast<xla::half>(value)));
literal =
xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";

View File

@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
{
std::vector<int64> int64_values = {1, 2, 3};
std::unique_ptr<xla::Literal> int64_values_literal =
xla::Literal int64_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
.error_message());
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
.error_message());
EXPECT_EQ(
"Cannot convert literal of type S64 to tensor of type qint32",
LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
.error_message());
EXPECT_TRUE(
LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
.ok());
LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
test::ExpectTensorEqual<int64>(host_tensor,
test::AsTensor<int64>(int64_values));
}
@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// Repeat tests with int32.
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
std::unique_ptr<xla::Literal> int32_values_literal =
xla::Literal int32_values_literal =
xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
.ok());
LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
test::ExpectTensorEqual<int32>(host_tensor,
test::AsTensor<int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
.ok());
std::vector<qint32> qint32_values = {10, 11};
test::ExpectTensorEqual<qint32>(host_tensor,
test::AsTensor<qint32>(qint32_values));
EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
.error_message());
}
}

View File

@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
auto x_global_or = client->TransferToServer(*x_literal);
auto y_global_or = client->TransferToServer(*y_literal);
auto x_global_or = client->TransferToServer(x_literal);
auto y_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
xla::Literal result = std::move(result_or.ValueOrDie());
EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */

View File

@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
// Tests that the compiler doesn't reorder the parameters.
@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
xla::Literal expected =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> input_base =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> input_grad2 =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::Literal> input =
xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*input).ConsumeValueOrDie();
client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> output_read =
xla::LiteralUtil::CreateR0<int32>(42);
std::unique_ptr<xla::Literal> output_base =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> output_grad1 =
xla::LiteralUtil::CreateR1<int32>({0, 1});
std::unique_ptr<xla::Literal> output_grad2 =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal output_resource =
xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
void RunAndCheckVariablesComputation(
xla::Client* client, const XlaCompiler::CompilationResult& result) {
std::unique_ptr<xla::Literal> param0_literal =
xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client->TransferToServer(*param0_literal).ConsumeValueOrDie();
client->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
client->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR1<int32>({5, 144});
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a simple graph that reads and writes a variable.
@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
std::move(graph), args, &result));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param1_literal =
xla::LiteralUtil::CreateR1<int32>({-3, 101});
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
std::unique_ptr<xla::Literal> param1_literal =
xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
xla::Literal expected_literal =
xla::LiteralUtil::MakeTuple({&expected0, &expected1});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.

View File

@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
&layout);
xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
" as a compile-time constant.\nError: ",
computed.status().error_message());
}
*constant_literal = std::move(*computed.ValueOrDie());
*constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}

View File

@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
Client::~Client() = default;
StatusOr<std::unique_ptr<Literal>> Client::Transfer(
const GlobalData& data, const Shape* shape_with_layout) {
StatusOr<Literal> Client::Transfer(const GlobalData& data,
const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
return Status::OK();
}
StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
@ -162,7 +162,7 @@ Status Client::ResetDevice() {
return Status::OK();
}
StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
StatusOr<Literal> Client::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
@ -177,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
const XlaComputation& computation, const Layout* output_layout) const {
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {

View File

@ -96,8 +96,8 @@ class Client {
//
// If shape_with_layout is not nullptr, it points to a shape whose layout will
// be the layout of the returned literal.
StatusOr<std::unique_ptr<Literal>> Transfer(
const GlobalData& data, const Shape* shape_with_layout = nullptr);
StatusOr<Literal> Transfer(const GlobalData& data,
const Shape* shape_with_layout = nullptr);
// Transfer the given literal to the server. This allocates memory on the
// device and copies the literal's contents over. Returns a global data handle
@ -122,7 +122,7 @@ class Client {
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
StatusOr<Literal> TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
@ -132,7 +132,7 @@ class Client {
// Executes the computation with the given arguments and transfers the result
// to the client as a literal. Parameters are defined the same as for
// Execute() and Transfer().
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
@ -153,7 +153,7 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
StatusOr<Literal> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;

View File

@ -76,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (DataSizeOfShape(shape) < (1LL << 20)) {
StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
@ -84,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.

View File

@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments(
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*argument));
*hlo_snapshot->add_arguments() = literal->ToProto();
TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
*hlo_snapshot->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments(
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*result));
*hlo_snapshot->mutable_result() = literal->ToProto();
TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
*hlo_snapshot->mutable_result() = literal.ToProto();
return Status::OK();
}
StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream,
backend_->BorrowStream(shaped_buffer.device_ordinal()));
@ -277,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
return std::move(scoped_buffer);
}
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
shaped_buffer.device_ordinal()));
@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
literal);
}
StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal) {
StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, literal.get()));
executor, shape, &literal));
return std::move(literal);
}

View File

@ -84,8 +84,7 @@ class LocalExecutable {
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer);
StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
@ -132,8 +131,7 @@ class LocalClient : public Client {
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer);
StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
@ -151,8 +149,8 @@ class LocalClient : public Client {
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal);
StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
int device_ordinal);
// Returns the device ordinal that corresponds to the given replica number.
//

View File

@ -738,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
*instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
*instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}

View File

@ -2112,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@ -2129,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
return ConstantLiteral(*LiteralUtil::CreateR1(values));
return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
*LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@ -2189,12 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@ -2207,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@ -2221,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
builder,
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
*LiteralUtil::CreateFromArray<NativeT>(values));
LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@ -2236,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
builder,
*LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@ -2253,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
*LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>

View File

@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
auto literal = absl::make_unique<Literal>(shape);
literal->root_piece_->ForEachMutableSubpiece(
Literal LiteralBase::CreateFromShape(const Shape& shape) {
Literal literal(shape);
literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
/* static */ StatusOr<std::unique_ptr<Literal>>
MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
auto literal = absl::make_unique<Literal>(proto.shape());
Literal literal(proto.shape());
TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
std::unique_ptr<Literal> LiteralBase::Relayout(
const Layout& new_layout, const ShapeIndex& shape_index) const {
Literal LiteralBase::Relayout(const Layout& new_layout,
const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
auto result = absl::make_unique<Literal>(new_shape);
TF_CHECK_OK(result->CopyFrom(*this));
Literal result(new_shape);
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
std::unique_ptr<Literal> LiteralBase::Relayout(
const Shape& shape_with_layout) const {
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
result->shape(),
result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
TF_CHECK_OK(result->CopyFrom(*this,
/*dest_shape_index=*/index,
/*src_shape_index=*/index));
TF_CHECK_OK(result.CopyFrom(*this,
/*dest_shape_index=*/index,
/*src_shape_index=*/index));
}
});
return result;
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
char* dest_data = static_cast<char*>(result->untyped_data());
char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
std::unique_ptr<Literal> output;
Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
output = CloneToUnique();
output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
*output->mutable_shape_do_not_use() =
*output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
ShapeUtil::HumanString(output->shape()));
ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
std::unique_ptr<Literal> LiteralBase::Transpose(
absl::Span<const int64> permutation) const {
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
auto new_literal = absl::make_unique<Literal>(permuted_shape);
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
Literal new_literal(permuted_shape);
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::SliceInternal(
Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
auto result_literal = absl::make_unique<Literal>(result_shape);
Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
result_literal->EachCell<NativeT>(
result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
result_literal->Set<NativeT>(indices, value);
result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
std::unique_ptr<Literal> LiteralBase::Slice(
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const {
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
return result;
}
std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
auto result = absl::make_unique<Literal>(shape());
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
const LiteralBase& src_literal, const ConverterType& converter) {
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
auto dest_data = result_literal->template data<NativeDestT>();
auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> ConvertBetweenNativeTypes(
const LiteralBase& src_literal) {
Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = absl::make_unique<Literal>(
Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
absl::Span<complex64> dest_data = result_literal->data<complex64>();
absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
bool bitcast) {
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
bool bitcast) {
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}
StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
const LiteralBase& literal, PrimitiveType primitive_dest_type,
bool bitcast) {
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
return literal.CloneToUnique();
return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@ -1362,8 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
dest_shape.element_type() == BF16) {
@ -1381,11 +1370,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
elements.push_back(std::move(*new_element));
elements.push_back(std::move(new_element));
}
auto converted = absl::make_unique<Literal>();
*converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
return std::move(converted);
return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(

View File

@ -223,25 +223,21 @@ class LiteralBase {
//
// TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
// the default behavior.
StatusOr<std::unique_ptr<Literal>> ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
bool round_f32_to_bf16 = false) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
StatusOr<std::unique_ptr<Literal>> BitcastConvert(
PrimitiveType primitive_dest_type) const;
StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
// Clones the underlying buffers into a new Literal, or new
// std::unique_ptr<Literal>.
// Clones the underlying buffers into a new Literal.
Literal Clone() const;
std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@ -259,24 +255,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
std::unique_ptr<Literal> Relayout(const Layout& new_layout,
const ShapeIndex& shape_index = {}) const;
Literal Relayout(const Layout& new_layout,
const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
StatusOr<std::unique_ptr<Literal>> Reshape(
absl::Span<const int64> dimensions) const;
StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
StatusOr<std::unique_ptr<Literal>> Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const;
StatusOr<Literal> Broadcast(const Shape& result_shape,
absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@ -285,7 +280,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@ -293,15 +288,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const;
Literal Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
std::unique_ptr<Literal> Replicate(int64 times) const;
Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@ -312,7 +307,7 @@ class LiteralBase {
// initialization, then reinitialization. Conside if a call to
// absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@ -539,8 +534,8 @@ class LiteralBase {
private:
template <typename NativeT>
std::unique_ptr<Literal> SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const;
Literal SliceInternal(const Shape& result_shape,
absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@ -687,8 +682,7 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto);
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
@ -1137,15 +1131,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal->shape());
Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
@ -1157,7 +1150,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
literal->Set<NativeT>(output_indices, element);
literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {

File diff suppressed because it is too large Load Diff

View File

@ -45,7 +45,7 @@ using absl::StrCat;
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
Literal ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@ -56,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
auto result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@ -67,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
auto dest = result->data<ToNativeT>(shape_index);
auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
TF_CHECK_OK(result->CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
TF_CHECK_OK(result.CopyFrom(literal,
/*dest_shape_index=*/shape_index,
/*src_shape_index=*/shape_index));
}
}
});
@ -83,53 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
/* static */ Literal LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
/* static */ Literal LiteralUtil::CreateToken() {
return Literal(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(*LiteralUtil::CreateR0<uint8>(0));
return LiteralUtil::CreateR0<uint8>(0);
case U32:
return std::move(*LiteralUtil::CreateR0<uint32>(0));
return LiteralUtil::CreateR0<uint32>(0);
case U64:
return std::move(*LiteralUtil::CreateR0<uint64>(0));
return LiteralUtil::CreateR0<uint64>(0);
case S8:
return std::move(*LiteralUtil::CreateR0<int8>(0));
return LiteralUtil::CreateR0<int8>(0);
case S32:
return std::move(*LiteralUtil::CreateR0<int32>(0));
return LiteralUtil::CreateR0<int32>(0);
case S64:
return std::move(*LiteralUtil::CreateR0<int64>(0));
return LiteralUtil::CreateR0<int64>(0);
case F16:
return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
return std::move(
*LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
return std::move(*LiteralUtil::CreateR0<float>(0));
return LiteralUtil::CreateR0<float>(0);
case F64:
return std::move(*LiteralUtil::CreateR0<double>(0));
return LiteralUtil::CreateR0<double>(0);
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(0));
return LiteralUtil::CreateR0<complex64>(0);
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(false));
return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@ -145,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(*LiteralUtil::CreateR0<uint8>(1));
return LiteralUtil::CreateR0<uint8>(1);
case U32:
return std::move(*LiteralUtil::CreateR0<uint32>(1));
return LiteralUtil::CreateR0<uint32>(1);
case U64:
return std::move(*LiteralUtil::CreateR0<uint64>(1));
return LiteralUtil::CreateR0<uint64>(1);
case S8:
return std::move(*LiteralUtil::CreateR0<int8>(1));
return LiteralUtil::CreateR0<int8>(1);
case S32:
return std::move(*LiteralUtil::CreateR0<int32>(1));
return LiteralUtil::CreateR0<int32>(1);
case S64:
return std::move(*LiteralUtil::CreateR0<int64>(1));
return LiteralUtil::CreateR0<int64>(1);
case F16:
return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
return std::move(
*LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
return std::move(*LiteralUtil::CreateR0<float>(1));
return LiteralUtil::CreateR0<float>(1);
case F64:
return std::move(*LiteralUtil::CreateR0<double>(1));
return LiteralUtil::CreateR0<double>(1);
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(1));
return LiteralUtil::CreateR0<complex64>(1);
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(true));
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@ -184,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
return std::move(*LiteralUtil::CreateR0<float>(
-std::numeric_limits<float>::infinity()));
return LiteralUtil::CreateR0<float>(
-std::numeric_limits<float>::infinity());
case F64:
return std::move(*LiteralUtil::CreateR0<double>(
-std::numeric_limits<double>::infinity()));
return LiteralUtil::CreateR0<double>(
-std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(false));
return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@ -232,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
return std::move(*LiteralUtil::CreateR0<float>(
std::numeric_limits<float>::infinity()));
return LiteralUtil::CreateR0<float>(
std::numeric_limits<float>::infinity());
case F64:
return std::move(*LiteralUtil::CreateR0<double>(
std::numeric_limits<double>::infinity()));
return LiteralUtil::CreateR0<double>(
std::numeric_limits<double>::infinity());
case PRED:
return std::move(*LiteralUtil::CreateR0<bool>(true));
return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
return LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@ -275,31 +261,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
auto literal = absl::make_unique<Literal>(
Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
literal->PopulateR1(values);
literal.PopulateR1(values);
return literal;
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
absl::string_view value) {
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
literal->Set<uint8>({i}, value[i]);
literal.Set<uint8>({i}, value[i]);
}
return literal;
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
float from, float to, int64 rows, int64 cols) {
/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
/* static */ Literal LiteralUtil::ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
@ -309,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
auto new_literal = absl::make_unique<Literal>(
Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
// writing elements to the new literal.
Shape shape_with_layout = new_literal->shape();
Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
@ -326,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
new_literal->Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
new_literal.Set<bool>(to_multi_index,
literal.Get<bool>(from_multi_index));
break;
case U8:
new_literal->Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
new_literal.Set<uint8>(to_multi_index,
literal.Get<uint8>(from_multi_index));
break;
case U32:
new_literal->Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
new_literal.Set<uint32>(to_multi_index,
literal.Get<uint32>(from_multi_index));
break;
case S32:
new_literal->Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
new_literal.Set<int32>(to_multi_index,
literal.Get<int32>(from_multi_index));
break;
case U64:
new_literal->Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
new_literal.Set<uint64>(to_multi_index,
literal.Get<uint64>(from_multi_index));
break;
case S64:
new_literal->Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
new_literal.Set<int64>(to_multi_index,
literal.Get<int64>(from_multi_index));
break;
case F32:
new_literal->Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
new_literal.Set<float>(to_multi_index,
literal.Get<float>(from_multi_index));
break;
case F64:
new_literal->Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
new_literal.Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
break;
case C64:
new_literal->Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
new_literal.Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@ -376,97 +360,82 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
return std::move(
*LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
return std::move(
*LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
return std::move(
*LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
return std::move(*LiteralUtil::CreateR0<bfloat16>(
literal.GetFirstElement<bfloat16>()));
return LiteralUtil::CreateR0<bfloat16>(
literal.GetFirstElement<bfloat16>());
case F16:
return std::move(
*LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
return std::move(
*LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
return std::move(
*LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
return std::move(
*LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
return std::move(
*LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
return std::move(
*LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
return std::move(*LiteralUtil::CreateR0<complex64>(
literal.GetFirstElement<complex64>()));
return LiteralUtil::CreateR0<complex64>(
literal.GetFirstElement<complex64>());
case F64:
return std::move(
*LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
return std::move(
*LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
return std::move(
*LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
/* static */ Literal LiteralUtil::MakeTuple(
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
/* static */ Literal LiteralUtil::MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
/* static */ Literal LiteralUtil::MakeTupleOwned(
std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
element_shapes.push_back(element->shape());
element_shapes.push_back(element.shape());
}
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}

View File

@ -69,36 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateR0(NativeT value);
static Literal CreateR0(NativeT value);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
static std::unique_ptr<Literal> CreateR1(
const tensorflow::core::Bitmap& values);
static Literal CreateR1(absl::Span<const NativeT> values);
static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2(
static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2WithLayout(
static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values);
static Literal CreateR3(std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3WithLayout(
static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4(
static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4WithLayout(
static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@ -139,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
static std::unique_ptr<Literal> CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort = true);
static Literal CreateSparse(absl::Span<const int64> dimensions,
SparseIndexArray indices,
absl::Span<const NativeT> values,
bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@ -155,130 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
static Literal CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout);
static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2D(
const Array2D<NativeT>& values);
static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout);
static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3FromArray3D(
const Array3D<NativeT>& values);
static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout);
static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4FromArray4D(
const Array4D<NativeT>& values);
static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout);
static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout);
// Creates a new vector of U8s literal value from a string.
static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
int64 rows, int64 cols);
static Literal CreateR2F32Linspace(float from, float to, int64 rows,
int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3Projected(
static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
static std::unique_ptr<Literal> CreateR4Projected(
static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
static std::unique_ptr<Literal> MakeTuple(
absl::Span<const Literal* const> elements);
static Literal MakeTuple(absl::Span<const Literal* const> elements);
static std::unique_ptr<Literal> MakeTupleFromSlices(
absl::Span<const LiteralSlice> elements);
static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
// std::vector<std::unique_ptr<Literal>> elements = ...;
// std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
static std::unique_ptr<Literal> MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements);
static Literal MakeTupleOwned(std::vector<Literal> elements);
// This overload lets you pass a braced list of unique_ptr<Literal>s to
// This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
// Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
// Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
// The arguments to this function must all be unique_ptr<Literal>.
// The arguments to this function must all be Literal.
template <typename... Ts>
static std::unique_ptr<Literal> MakeTupleOwned(
std::unique_ptr<Ts>... elements) {
std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
std::move(elements)...};
std::vector<std::unique_ptr<Literal>> v;
static Literal MakeTupleOwned(Ts... elements) {
std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
// Create a constant token literal. Token types have no value.
static std::unique_ptr<Literal> CreateToken();
static Literal CreateToken();
// Creates a new Literal object with its values havings the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
static std::unique_ptr<Literal> CreateFromDimensions(
PrimitiveType primitive_type, absl::Span<const int64> dimensions);
static Literal CreateFromDimensions(PrimitiveType primitive_type,
absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
static std::unique_ptr<Literal> ConvertBF16ToF32(
const LiteralSlice& bf16_literal);
static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
static std::unique_ptr<Literal> ConvertF32ToBF16(
const LiteralSlice& f32_literal);
static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
static std::unique_ptr<Literal> ReshapeSlice(
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major,
const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@ -286,7 +275,7 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator);
@ -297,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev);
static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@ -307,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape, T mean, T stddev);
static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
T stddev);
//
// End of factory methods.
@ -322,44 +311,43 @@ class LiteralUtil {
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
literal->Set({}, value);
literal.Set({}, value);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
absl::Span<const NativeT> values) {
auto literal = absl::make_unique<Literal>(
/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
literal->PopulateR1(values);
literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
literal->PopulateR2(values);
literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@ -384,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@ -422,23 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
/* static */ Literal LiteralUtil::CreateSparse(
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
indices.max_indices()));
literal->PopulateSparse(indices, values, sort);
Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
indices.max_indices()));
literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@ -446,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
literal->PopulateFromArray(values);
literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout) {
/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout) {
/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@ -514,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@ -542,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout) {
/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@ -565,33 +549,29 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
NativeT value) {
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
literal->PopulateWithValue(value);
/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
absl::Span<const int64> dimensions, NativeT value) {
Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
auto literal = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
Literal literal(shape);
TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
T stddev) {
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
}
template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}

View File

@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
const Shape& shape, const Layout* layout) {
StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()));
}
auto result = absl::make_unique<Literal>(literal_shape);
result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
Literal result(literal_shape);
result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
absl::Span<const float> field = result->data<float>();
absl::Span<const float> field = result.data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
absl::string_view sp;

View File

@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
const Layout* layout = nullptr);
StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.

View File

@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
const Shape& shape, int replica_number) {
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@ -141,9 +141,8 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
return ToBuffer(client, /*device_ordinal=*/0, *relaid);
Literal relaid = argument.Relayout(shape_with_layout.value());
return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@ -151,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
@ -169,7 +168,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@ -198,9 +197,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, *relaid);
Literal relaid = argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}

View File

@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
const Shape& shape, int replica_number);
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@ -65,7 +65,7 @@ class LocalShapedBuffer {
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
StatusOr<Literal> ToLiteral() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@ -117,7 +117,7 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
StatusOr<std::unique_ptr<Literal> > Execute(
StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
const std::vector<absl::optional<Shape> >& shapes_with_layout);

View File

@ -216,9 +216,9 @@ tensorflow::ImportNumpy();
}
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@ -346,25 +346,25 @@ tensorflow::ImportNumpy();
// Literal
%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
$1 = literal_status.ValueOrDie().get();
$1 = &literal_status.ValueOrDie();
}
%typemap(out) std::unique_ptr<Literal> {
%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
$result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@ -375,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;

View File

@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}
StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
std::vector<std::unique_ptr<Literal>> elements;
std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@ -389,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
TF_RETURN_IF_ERROR(
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(

View File

@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.

View File

@ -529,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@ -546,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@ -556,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@ -565,7 +565,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;
const Shape& shape = ShapeInference::InferConvolveShape(
lhs_literal->shape(), rhs_literal->shape(),
lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();
@ -585,18 +585,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
std::unique_ptr<Literal> result_literal =
Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
result_literal->shape().dimensions(1),
result_literal->shape().dimensions(2),
result_literal->shape().dimensions(3));
absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
result_literal.shape().dimensions(1),
result_literal.shape().dimensions(2),
result_literal.shape().dimensions(3));
result->Each([&](absl::Span<const int64> indices, float* value) {
*value = result_literal->Get<float>(indices);
*value = result_literal.Get<float>(indices);
});
return result;

View File

@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}
@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}
@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}
@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
*actual_literal, ErrorSpec(0.0001));
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray3D) {
@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}
@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
*actual_literal, ErrorSpec(0.0001));
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray4D) {
@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
actual_literal, ErrorSpec(0.0001));
}
} // namespace

View File

@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<float>(expected);
Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}

View File

@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@ -527,7 +527,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
HloInstruction::CreateConstant(literal.CloneToUnique()));
HloInstruction::CreateConstant(literal.Clone()));
}
}
@ -546,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@ -676,7 +676,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
HloInstruction::CreateConstant((new_literal.CloneToUnique())));
HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@ -1469,7 +1469,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
LiteralUtil::Zero(iota->shape().element_type()).Clone()));
return ReplaceWithNewInstruction(
iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
}
@ -1572,7 +1572,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
LiteralUtil::One(power->shape().element_type()).CloneToUnique());
LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@ -1607,7 +1607,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@ -2062,7 +2062,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (!converted_pad_literal.ok()) {
return false;
}
return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@ -2223,8 +2223,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(convolution->shape().element_type())
.CloneToUnique())),
LiteralUtil::Zero(convolution->shape().element_type()))),
{}));
}

View File

@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<float>(constant_scalar).get(),
LiteralUtil::CreateR1<float>(constant_vector).get()});
Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
LiteralUtil::CreateR1<float>(constant_vector)};
Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());

View File

@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,

View File

@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}

View File

@ -1245,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()})));
LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());

View File

@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
auto inner_tuple0 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
LiteralUtil::CreateR0<int64>(1).get()});
auto inner_tuple1 =
LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
LiteralUtil::CreateR0<int64>(1)};
auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
Literal element1 = LiteralUtil::CreateR0<int64>(3);
auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());

View File

@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(expanded_filter_shape.element_type()))));
auto zero = add(HloInstruction::CreateConstant(
LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(

View File

@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
Shape vshape = input_literal1->shape();
Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@ -78,13 +78,13 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@ -125,8 +125,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
error_spec_);
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@ -213,7 +212,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
*result, error_spec_);
result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
Shape vshape = input_literal->shape();
Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));

View File

@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR4(
TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(false).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<uint32>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
LiteralUtil::CreateR0<bool>(false).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
LiteralUtil::CreateR0<bool>(false).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
LiteralUtil::CreateR0<bool>(true).get()})));
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 infeed data should be added.
LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace

View File

@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));

View File

@ -56,9 +56,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {lhs.get(), rhs.get()});
Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla

View File

@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
std::unique_ptr<Literal> relayed_out_literal;
Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
source = relayed_out_literal->untyped_data();
source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,

View File

@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
Array4D<float> constant_arr(4, 4, 2, 2);
constant_arr.FillIota(0);
string constant_str =
LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
ParseAndVerifyModule(absl::StrFormat(R"(
HloModule test

View File

@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
auto* zero = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));

View File

@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
HloInstruction* padding = computation->AddInstruction(
HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();

View File

@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(

View File

@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
TestInfeedRoundTrip(*LiteralUtil::CreateR4(
TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
LiteralUtil::CreateR0<bool>(false).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<uint32>({1, 2, 3}),
LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR4FromArray4D<float>(array).get(),
LiteralUtil::CreateR0<int32>(5).get()}));
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4FromArray4D<float>(array),
LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace

View File

@ -76,10 +76,10 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
if (result == nullptr) {
if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;

View File

@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique();
auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}

View File

@ -321,18 +321,17 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(operand->shape().element_type()))));
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}

View File

@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
Literal result_literal,
evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
CHECK_EQ(*result_literal,
*LiteralUtil::CreateR2<int32>(
CHECK_EQ(result_literal,
LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
Literal result_literal,
evaluator.Evaluate<Literal>(*module,
{LiteralUtil::CreateR1<int32>({9, 10})}));
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
Literal result_literal,
evaluator.Evaluate<Literal>(*module,
{LiteralUtil::CreateR1<int32>({9, 10})}));
CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, {LiteralUtil::CreateR0<int32>(9)}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
Literal result_literal,
evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
CHECK_EQ(*result_literal,
*LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
CHECK_EQ(result_literal,
LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, {LiteralUtil::CreateR0<int32>(0)}));
CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
TF_ASSERT_OK_AND_ASSIGN(
Literal result_literal,
evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
CHECK_EQ(*result_literal,
*LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
CHECK_EQ(result_literal,
LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace

View File

@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {

View File

@ -54,9 +54,8 @@ namespace xla {
namespace {
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@ -94,9 +93,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<bool>([&](absl::Span<const int64> multi_index) {
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
@ -105,9 +104,9 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@ -125,9 +124,9 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<bool>([&](absl::Span<const int64> multi_index) {
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
}
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
StatusOr<Literal> HloEvaluator::Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
@ -206,11 +205,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
.CloneToUnique();
.Clone();
}
template <>
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
const HloModule& module, absl::Span<const Literal> arg_literals) {
std::vector<const Literal*> arg_literal_ptrs;
for (const auto& literal_ptr : arg_literals) {
arg_literal_ptrs.push_back(&literal_ptr);
}
return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
@ -224,11 +233,21 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
}
template <>
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
const HloComputation& computation, absl::Span<const Literal> arg_literals) {
std::vector<const Literal*> arg_literal_ptrs;
for (const auto& literal_ptr : arg_literals) {
arg_literal_ptrs.push_back(&literal_ptr);
}
return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
StatusOr<Literal> HloEvaluator::Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
@ -247,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
evaluated_[operand] = input_literal->CloneToUnique();
evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
return GetEvaluatedLiteralFor(instruction).CloneToUnique();
return GetEvaluatedLiteralFor(instruction).Clone();
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction) {
template <>
StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
std::vector<const Literal*> arg_literal_ptrs;
for (const auto& literal : arg_literals) {
arg_literal_ptrs.push_back(&literal);
}
return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
}
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@ -274,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
return GetEvaluatedLiteralFor(instruction).CloneToUnique();
return GetEvaluatedLiteralFor(instruction).Clone();
}
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
HloInstruction* instruction) {
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
return nullptr;
return false;
}
return result_or.ConsumeValueOrDie();
*result = result_or.ConsumeValueOrDie();
return true;
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@ -299,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
HloInstruction::CreateConstant(it->second->CloneToUnique()));
HloInstruction::CreateConstant(it->second->Clone()));
}
}
@ -316,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.CloneToUnique());
HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@ -331,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
HloInstruction::CreateConstant(operand.CloneToUnique());
HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@ -343,14 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.CloneToUnique());
HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
evaluated_[parameter] = input_literal->CloneToUnique();
evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex {
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
std::unique_ptr<Literal>* reshaped_start_indices) {
Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
@ -834,16 +863,16 @@ static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
return std::cref(**reshaped_start_indices);
return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
std::unique_ptr<Literal> reshaped_start_indices;
Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
result->CopyElementFrom(operand, input_index, output_index));
result.CopyElementFrom(operand, input_index, output_index));
return true;
};
@ -977,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
evaluated_[get_tuple_element] = absl::make_unique<Literal>(
ShapeUtil::GetTupleElementShape(operand->shape(), index));
return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
/*dest_shape_index=*/{},
/*src_shape_index=*/{index});
evaluated_[get_tuple_element] =
Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
/*dest_shape_index=*/{},
/*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
evaluated_[copy] = std::move(result);
evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@ -1004,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
std::unique_ptr<Literal> result =
Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@ -1036,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
std::unique_ptr<Literal> result =
Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@ -1056,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
std::unique_ptr<Literal> result;
Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@ -1081,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
evaluated_[select] = on_true.CloneToUnique();
evaluated_[select] = on_true.Clone();
} else {
evaluated_[select] = on_false.CloneToUnique();
evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@ -1097,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
evaluated_[tuple_select] = on_true.CloneToUnique();
evaluated_[tuple_select] = on_true.Clone();
} else {
evaluated_[tuple_select] = on_false.CloneToUnique();
evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@ -1108,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop carried valued with the input to the While instruction.
auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
@ -1118,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
*cond_comp, {lcv.get()}));
keep_going = cond_val->GetFirstElement<bool>();
TF_ASSIGN_OR_RETURN(auto cond_val,
cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
*body_comp, {lcv.get()}));
VLOG(3) << "Loop iteration result: " << body_val->ToString();
*body_comp, {&lcv}));
VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@ -1139,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
HloInstruction* sort, const Literal& keys_literal,
const Literal& values_literal) {
StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@ -1179,57 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
auto result_values_literal =
absl::make_unique<Literal>(values_literal.shape());
result_values_literal->PopulateR1(
Literal result_keys_literal(keys_literal.shape());
result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
Literal result_values_literal(values_literal.shape());
result_values_literal.PopulateR1(
absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
std::unique_ptr<Literal> result_tuple;
Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
result_tuple = LiteralUtil::MakeTuple(
{result_pair.first.get(), result_pair.second.get()});
result_tuple =
LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
auto values_result_literal =
absl::make_unique<Literal>(values_literal.shape());
Literal keys_result_literal(keys_literal.shape());
Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
->Reshape({r1_length}));
.Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
->Reshape({r1_length}));
auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
.Reshape({r1_length}));
auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
r1_result_pair.first->Reshape({1, r1_length}));
r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
r1_result_pair.second->Reshape({1, r1_length}));
TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
*sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
*sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
r1_result_pair.second.Reshape({1, r1_length}));
TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
result_tuple = LiteralUtil::MakeTuple(
{keys_result_literal.get(), values_result_literal.get()});
result_tuple =
LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
template <typename KeyType>
StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
HloInstruction* sort, const Literal& keys_literal,
const Literal& values_literal) {
StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@ -1248,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
StatusOr<Literal> EvaluateSort(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@ -1319,28 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
// Explicit instantiation of templatized Evaluate* methods.
//
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<const Literal*>(
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloModule& module, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
const HloModule& module,
absl::Span<const std::unique_ptr<Literal>> arg_literals);
template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
const Literal*>(const HloComputation& computation,
absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
absl::Span<const std::unique_ptr<Literal>> arg_literals);
absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<const Literal*>(
template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
HloInstruction* instruction,
absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla

View File

@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
// `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
StatusOr<Literal> Evaluate(const HloModule& module,
absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
// `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals);
StatusOr<Literal> Evaluate(const HloComputation& computation,
absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
// `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
// `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
StatusOr<Literal> Evaluate(HloInstruction* instruction,
absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
StatusOr<Literal> Evaluate(HloInstruction* instruction);
// Same as Evaluate, except returning nullptr on error.
std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
// Same as Evaluate, except returning false on error and accepts an output
// pointer.
bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs);
StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
const Literal& lhs,
const Literal& rhs);
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand);
StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
const Literal& operand);
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs);
StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
return *(it->second);
return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
// post-orderring.
// Must be cleared for each evaluation.
tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
evaluated_;
// Storing Literal in place require the container to have pointer stability so
// we cannot use FlatMap any more.
std::unordered_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()));
}
auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);

File diff suppressed because it is too large Load Diff

View File

@ -246,15 +246,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
parent_->evaluated_[convert] =
result->Relayout(convert->shape().layout());
parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@ -262,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
parent_->evaluated_[convert] =
result->Relayout(convert->shape().layout());
parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@ -978,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@ -1157,8 +1155,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
Literal result(result_shape);
TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@ -1231,9 +1229,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
auto result = absl::make_unique<Literal>(dot->shape());
Literal result(dot->shape());
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@ -1280,8 +1278,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
auto result = absl::make_unique<Literal>(pad->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
Literal result(pad->shape());
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
@ -1289,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@ -1311,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
result->Set<ReturnT>(target_index,
evaluated_operand.Get<ReturnT>(input_index));
result.Set<ReturnT>(target_index,
evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@ -1439,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
auto result = absl::make_unique<Literal>(map->shape());
Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
std::vector<std::unique_ptr<Literal>> arg_literals;
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@ -1463,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
std::unique_ptr<Literal> computed_result =
embedded_evaluator
.Evaluate<std::unique_ptr<Literal>>(*computation,
arg_literals)
Literal computed_result =
embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on
// the same computation.
embedded_evaluator.ResetVisitStates();
return computed_result->Get<ReturnT>({});
return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@ -1557,9 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
Literal result_literal(keys_literal.shape());
result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@ -1568,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
->Reshape({r1_length}));
auto r1_result = sort_r1(*r1_slice);
TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
*r1_result, {0, 0}, {row, 0}, {1, r1_length}));
.Reshape({r1_length}));
auto r1_result = sort_r1(r1_slice);
TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@ -1651,9 +1647,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
results[i] = absl::make_unique<Literal>(result_shape);
results[i] = Literal(result_shape);
}
Status eval_status;
@ -1667,7 +1663,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
for (int64 input = 0; input < num_args; ++input) {
TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
[&](absl::Span<const int64> multi_index) {
if (!eval_status.ok()) {
return init_scalars[input];
@ -1703,8 +1699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
// Evaluate computation with specified literal operands.
absl::InlinedVector<std::unique_ptr<Literal>, 1>
embedded_operands;
absl::InlinedVector<Literal, 1> embedded_operands;
for (ReturnT value : result_values) {
embedded_operands.push_back(
LiteralUtil::CreateR0<ReturnT>(value));
@ -1717,11 +1712,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_operands.size());
std::transform(embedded_operands.begin(), embedded_operands.end(),
embedded_operands_ptrs.begin(),
[](const std::unique_ptr<Literal>& ptr) {
return ptr.get();
});
[](Literal& literal) { return &literal; });
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
TF_ASSIGN_OR_RETURN(Literal computed_result,
embedded_evaluator.Evaluate<const Literal*>(
*function, embedded_operands_ptrs));
// Clear visit states so that we can use the evaluator again on
@ -1729,10 +1722,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
// Assign computed result to result_val.
if (!has_tuple_output) {
result_values[0] = computed_result->Get<ReturnT>({});
result_values[0] = computed_result.Get<ReturnT>({});
} else {
for (int64 i = 0; i < num_args; ++i) {
result_values[i] = computed_result->Get<ReturnT>(
result_values[i] = computed_result.Get<ReturnT>(
/*multi_index=*/{}, /*shape_index=*/{i});
}
}
@ -1748,9 +1741,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!has_tuple_output) {
parent_->evaluated_[reduce] = std::move(results[0]);
} else {
auto tuple_result = absl::make_unique<Literal>(reduce->shape());
Literal tuple_result(reduce->shape());
for (int64 i = 0; i < num_args; ++i) {
TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
parent_->evaluated_[reduce] = std::move(tuple_result);
}
@ -1781,10 +1774,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
auto result = absl::make_unique<Literal>(select_and_scatter->shape());
Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
[&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
@ -1834,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
curr_val_literal->Set({}, curr_val);
selected_val_literal->Set({}, *selected_val);
std::unique_ptr<Literal> computed_result =
curr_val_literal.Set({}, curr_val);
selected_val_literal.Set({}, *selected_val);
Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
*select,
{selected_val_literal.get(), curr_val_literal.get()})
*select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
bool selected = !computed_result->Get<bool>({});
bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@ -1856,16 +1848,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
auto scattered = result->Get<ReturnT>(operand_index);
source_literal_scatter->Set({}, source);
scattered_literal->Set({}, scattered);
std::unique_ptr<Literal> computed_result =
auto scattered = result.Get<ReturnT>(operand_index);
source_literal_scatter.Set({}, source);
scattered_literal.Set({}, scattered);
Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(*scatter,
{source_literal_scatter.get(),
scattered_literal.get()})
.Evaluate<const Literal*>(
*scatter,
{&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
result->Set(operand_index, computed_result->Get<ReturnT>({}));
result.Set(operand_index, computed_result.Get<ReturnT>({}));
// Clear visit states so that the we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
@ -1916,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
auto result = absl::make_unique<Literal>(reduce_window->shape());
Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@ -1935,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
*function,
{result_val_literal.get(), curr_val_literal.get()})
*function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again
// on the same computation.
embedded_evaluator.ResetVisitStates();
result_val = computed_result->Get<ReturnT>({});
result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@ -1961,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
std::unique_ptr<Literal>* reshaped_indices) {
Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@ -1970,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
return std::cref(**reshaped_indices);
return std::cref(*reshaped_indices);
}
// Returns an ShapeUtil::IndexIterationSpace that iterates over the update
@ -2230,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
std::unique_ptr<Literal> reshaped_scatter_indices;
Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@ -2260,7 +2251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
std::unique_ptr<Literal> result = operand.CloneToUnique();
Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
[&](absl::Span<const int64> update_window_index,
@ -2299,19 +2290,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
std::unique_ptr<Literal> updated_result =
Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
{result_value_literal.get(), update_value_literal.get()})
{&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again on the
// same computation.
embedded_evaluator.ResetVisitStates();
result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
@ -2361,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@ -2575,7 +2566,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (ShapeUtil::Rank(iota->shape()) > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
result->Broadcast(iota->shape(), {iota->iota_dimension()}));
result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
parent_->evaluated_[iota] = std::move(result);
@ -2645,9 +2636,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicSlice(
const Literal& operand_literal, const Literal& start_indices_literal,
const Shape& result_shape) {
StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
const Literal& start_indices_literal,
const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@ -2660,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
auto result = absl::make_unique<Literal>(result_shape);
Literal result(result_shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@ -2676,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
const Literal& operand_literal, const Literal& update_literal,
const Literal& start_indices_literal) {
auto result = operand_literal.CloneToUnique();
StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
const Literal& update_literal,
const Literal& start_indices_literal) {
auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
const auto rank = ShapeUtil::Rank(result->shape());
const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@ -2689,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
result->shape().dimensions(i) - update_literal.shape().dimensions(i));
result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
result->Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
result.Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
return true;
};
@ -2710,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@ -2723,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@ -2745,10 +2736,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@ -2757,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@ -2782,10 +2773,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
auto result = absl::make_unique<Literal>(shape);
Literal result(shape);
TF_RETURN_IF_ERROR(
result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));

View File

@ -250,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@ -527,7 +527,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
std::unique_ptr<Literal> literal) {
Literal literal) {
return absl::make_unique<HloConstantInstruction>(std::move(literal));
}

View File

@ -359,8 +359,7 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
static std::unique_ptr<HloInstruction> CreateConstant(
std::unique_ptr<Literal> literal);
static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,

View File

@ -845,8 +845,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
: HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
HloConstantInstruction::HloConstantInstruction(Literal literal)
: HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
if (literal_ != nullptr) {
if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
literal_ = literal_->Relayout(new_layout, shape_index);
*literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@ -893,7 +893,8 @@ std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
CHECK(literal_.has_value());
return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@ -901,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
if (literal_ != nullptr &&
if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_literal() = literal_->ToProto();
*proto.mutable_literal() = literal_.ToProto();
return proto;
}

View File

@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns whether there is literal associated with this instruction.
bool HasLiteral() const { return literal_ != nullptr; }
bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
string TracingTag() const { return literal_->GetR1U8AsString(); }
string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
Literal literal_;
};
class HloFusionInstruction : public HloInstruction {

View File

@ -105,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
bool ParseDenseLiteral(Literal* literal, const Shape& shape);
bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape);
bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
std::unique_ptr<Literal> literal;
Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
std::vector<std::unique_ptr<Literal>> elements(
ShapeUtil::TupleElementCount(shape));
std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
linear_index++, literal->get())) {
linear_index++, literal)) {
return false;
}
lexer_.Lex();
@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(value, linear_index++, literal->get())) {
if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(value, linear_index++, literal->get())) {
if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
*literal = (*literal)->Relayout(shape.layout());
*literal = literal->Relayout(shape.layout());
return true;
}
bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape) {
bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
*literal = absl::make_unique<Literal>(shape);
*literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
if ((*literal)->sparse_element_count() + 1 ==
if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
(*literal)->AppendSparseElement(index, value);
literal->AppendSparseElement(index, value);
}
(*literal)->SortSparseElements();
literal->SortSparseElements();
return true;
}

View File

@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
const absl::Span<const std::unique_ptr<Literal>> literals) {
const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
literal_pointers.push_back(literal.get());
literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@ -135,7 +135,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes, ExecutionProfile* profile) {
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal> arguments,
bool run_hlo_passes,
ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(argument.get());
argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@ -290,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
auto literal = absl::make_unique<Literal>();
Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, options.outfeed_shape, literal.get()));
executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@ -310,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
std::vector<std::unique_ptr<Literal>> exec_results;
std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));

View File

@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@ -106,24 +106,23 @@ class HloRunner {
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
const absl::Span<const std::unique_ptr<Literal>> literals);
StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
const ShapedBuffer& buffer);
const absl::Span<const Literal> literals);
StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
const absl::Span<const Literal> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@ -140,7 +139,7 @@ class HloRunner {
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);

View File

@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(F32).CloneToUnique())),
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
padding_config));
auto module = CreateNewModule();
@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
padding_config.add_dimensions()->set_interior_padding(-1);
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {100}), param,
builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(F32).CloneToUnique())),
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
padding_config));
auto module = CreateNewModule();

View File

@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> inner_broadcast_result,
Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);

View File

@ -347,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
return owned_literals_.back().get();
return &owned_literals_.back();
}
StatusOr<Literal*> TakeOwnership(
StatusOr<std::unique_ptr<Literal>> literal_or_error) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
std::move(literal_or_error));
StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
return owned_literals_.back().get();
return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<std::unique_ptr<Literal>> owned_literals_;
std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};

View File

@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}

View File

@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
std::vector<std::unique_ptr<Literal>> arg_literals;
std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
std::unique_ptr<Literal> result_literal;
Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
TF_ASSIGN_OR_RETURN(result_literal,
evaluator_->Evaluate<std::unique_ptr<Literal>>(
*computation, arg_literals));
TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
*computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
result_literal->shape(), run_options->allocator(),
result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
run_options->stream(), *result_literal, result));
run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();

View File

@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));

View File

@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
*module->add_arguments() = literal->ToProto();
*module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
*module->mutable_result() = literal->ToProto();
*module->mutable_result() = literal.ToProto();
return Status::OK();
}
@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> result_literal,
Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
result_literal->shape())) {
*result->mutable_literal() = result_literal->ToProto();
if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
*result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
result_literal->Relayout(*return_shape)->ToProto();
result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@ -959,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
const Shape& shape = literal->shape();
const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
stream.get(), *literal, shaped_buffer));
stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
executor, *literal);
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), *literal));
*result->mutable_literal() = literal->ToProto();
executor, arg->shape_with_layout(), literal));
*result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(auto result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module, /*arg_literals=*/{}));
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
*module, /*arg_literals=*/{}));
// Since the result layout is non-effective to the Evaluator results, explicit
// relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
result_literal = result_literal->Relayout(arg->output_layout());
result_literal = result_literal.Relayout(arg->output_layout());
}
*result->mutable_literal() = result_literal->ToProto();
*result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}

View File

@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
StatusOr<std::unique_ptr<Literal>> ret;
StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@ -63,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
return absl::make_unique<Literal>(std::move(literal));
return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
StatusOr<std::unique_ptr<Literal>> ret;
StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@ -122,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
return absl::make_unique<Literal>(std::move(literal));
return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(

View File

@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source);
StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
const Shape& shape,
const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.

View File

@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
auto tuple_constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
LiteralUtil::CreateR1<float>({2.0, 42})};
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));

View File

@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
StatusOr<std::unique_ptr<Literal>> indvar_init_result =
evaluator.Evaluate(indvar_init);
StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@ -197,31 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
std::unique_ptr<Literal> indvar_iter_val =
std::move(indvar_init_result).ValueOrDie();
Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
StatusOr<std::unique_ptr<Literal>> result =
evaluator.EvaluateWithSubstitutions(
while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
StatusOr<std::unique_ptr<Literal>> indvar_next_result =
evaluator.EvaluateWithSubstitutions(
while_body_indvar_update,
{{while_body_indvar, indvar_iter_val.get()}});
StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();

View File

@ -41,7 +41,6 @@ limitations under the License.
namespace xla {
namespace {
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
ErrorSpec error_spec_{0.0001, 0.0001};
@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Add(lhs_param, rhs_param);
@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<int64> rhs{-1,
0,
@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Sub(lhs_param, rhs_param);
@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
XlaBuilder b(TestName());
std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
Lt(lhs_param, rhs_param);
ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
client_->TransferToServer(b_literal).ConsumeValueOrDie();
auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);
auto sum1 = Add(a_constant, b_constant);
@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto sum = ConstantR0<float>(&b, 0.0f);
auto param = Parameter(&b, 0, param_literal->shape(), "param");
auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
client_->TransferToServer(literal0).ConsumeValueOrDie();
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
client_->TransferToServer(literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
client_->TransferToServer(literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
client_->TransferToServer(literal1).ConsumeValueOrDie();
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
client_->TransferToServer(literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
client_->TransferToServer(literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
client_->TransferToServer(literal1).ConsumeValueOrDie();
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
client_->TransferToServer(literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
client_->TransferToServer(literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
client_->TransferToServer(literal1).ConsumeValueOrDie();
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
client_->TransferToServer(literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
client_->TransferToServer(literal0).ConsumeValueOrDie();
std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
client_->TransferToServer(literal1).ConsumeValueOrDie();
std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
client_->TransferToServer(literal2).ConsumeValueOrDie();
std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
Literal literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
client_->TransferToServer(*literal3).ConsumeValueOrDie();
client_->TransferToServer(literal3).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
auto param3 = Parameter(&b, 3, literal3.shape(), "param2");
Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
Literal param1_literal =
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
client_->TransferToServer(param1_literal).ConsumeValueOrDie();
auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
Array3D<float> expected(0, 7, 0);
@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
-0.79, 1.41, 1.21, 1.05});
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
client_->TransferToServer(*input_literal));
client_->TransferToServer(input_literal));
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Tanh(input);
ComputeAndCompareR1<float>(
@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
Literal input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
client_->TransferToServer(input_literal));
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Exp(input);
std::vector<float> expected_result;
int64 input_size = input_literal->shape().dimensions(0);
int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
expected_result.push_back(std::exp(input_literal->Get<float>({i})));
expected_result.push_back(std::exp(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
Literal input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
client_->TransferToServer(input_literal));
auto input = Parameter(&builder, 0, input_literal->shape(), "input");
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Log(input);
std::vector<float> expected_result;
int64 input_size = input_literal->shape().dimensions(0);
int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
expected_result.push_back(std::log(input_literal->Get<float>({i})));
expected_result.push_back(std::log(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
auto expected = LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
std::unique_ptr<Literal> a_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
auto a = ConstantLiteral(&builder, *a_literal);
Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
auto a = ConstantLiteral(&builder, a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
auto y = Parameter(&builder, 1, y_literal->shape(), "y");
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, y_literal.shape(), "y");
auto slice = Slice(x, {1}, {2}, {1});
Sub(slice, y);

View File

@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
{{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
LiteralUtil::CreateR1<float>({4, 5}).get(),
LiteralUtil::CreateR1<float>({5, 5}).get()});
{{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
LiteralUtil::CreateR1<float>({4, 5}),
LiteralUtil::CreateR1<float>({5, 5})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
{{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
LiteralUtil::CreateR1<float>({4, 5}).get(),
LiteralUtil::CreateR1<float>({5, 5}).get()});
{{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
LiteralUtil::CreateR1<float>({4, 5}),
LiteralUtil::CreateR1<float>({5, 5})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
{LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
{{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
{{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
{{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
LiteralUtil::CreateR1<float>({0, 0}).get(),
LiteralUtil::CreateR1<float>({16, 20}).get()});
{{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
LiteralUtil::CreateR1<float>({0, 0}),
LiteralUtil::CreateR1<float>({16, 20})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
Parameter(&builder, 1, scale_literal->shape(), "offset");
Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
Parameter(&builder, 2, offset_literal.shape(), "scale");
auto expected = LiteralUtil::MakeTuple(
{expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
LiteralUtil::CreateR1<float>(var).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{expected_normalized, LiteralUtil::CreateR1<float>(mean),
LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
&builder, *expected,
&builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
Parameter(&builder, 1, scale_literal->shape(), "offset");
Parameter(&builder, 1, scale_literal.shape(), "offset");
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
Parameter(&builder, 2, offset_literal.shape(), "scale");
auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
Parameter(&builder, 4, var_literal->shape(), "variance");
Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
auto input_parameter =
Parameter(&builder, 0, input_literal->shape(), "input");
auto scale_parameter =
Parameter(&builder, 1, scale_literal->shape(), "scale");
auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
auto expected =
LiteralUtil::MakeTuple({expected_grad_activation.get(),
LiteralUtil::CreateR1<float>(grad_scale).get(),
LiteralUtil::CreateR1<float>(grad_offset).get()});
auto expected = LiteralUtil::MakeTupleFromSlices(
{expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(&builder, *expected,
ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));

View File

@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
auto expected = LiteralUtil::MakeTuple(
auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace

View File

@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
client_->TransferToServer(*r3_data).ConsumeValueOrDie();
client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
client_->TransferToServer(*r2_data).ConsumeValueOrDie();
client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r3_implicit_global_data.get(), r3_global_data.get()},
&builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
&b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
&builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
auto r2 =
ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
auto r2 =
ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});

View File

@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(*result, {0}), error_spec_));
LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(*result, {1}), error_spec_));
LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
*result, error_spec_));
LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(
LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
*result, error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace

View File

@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
auto constant =
ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
auto tuple = LiteralUtil::MakeTuple({elem.get()});
Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
auto tuple = LiteralUtil::MakeTuple({&elem});
Call(&builder, callee, {ConstantLiteral(&builder, elem)});
ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace

View File

@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
Add(p0, p1);
auto param0_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto param1_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto computation = computation_status.ConsumeValueOrDie();
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
// Match
auto status = client_->Execute(

View File

@ -101,7 +101,7 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
@ -121,8 +121,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
StatusOr<std::unique_ptr<Literal>>
ClientLibraryTestBase::ExecuteAndTransferReference(
StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
return result.ValueOrDie()->ToString();
return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
verify_output(*actual, "");
verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
verify_output(*actual,
verify_output(actual,
absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
if (ShapeUtil::IsTuple(literal->shape())) {
if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
ShapeUtil::HumanStringWithLayout(literal->shape()));
ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
client_->TransferToServer(*literal_relayout));
client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@ -256,7 +255,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
}
verify_output(*actual, error_message);
verify_output(actual, error_message);
return Status::OK();
};
@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
Literal expected_literal = LiteralUtil::CreateR1U8(expected);
VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
VLOG(1) << "expected: " << expected_literal.ToString();
VLOG(1) << "actual: " << actual.ToString();
EXPECT_EQ(expected, actual->GetR1U8AsString());
EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare(
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
? *LiteralUtil::ConvertF32ToBF16(literal)
: literal);
? LiteralUtil::ConvertF32ToBF16(literal)
: LiteralSlice(literal));
}
std::unique_ptr<GlobalData>
@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}

View File

@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
StatusOr<Literal> ExecuteAndTransfer(
XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// This executes the computation via the reference client (which connects a
// interpreter backend). The result is used as the expected values of the
// computation.
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ComputeValueAndReference(XlaBuilder* builder,
absl::Span<const Literal> arguments);
StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
XlaBuilder* builder, absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@ -412,9 +411,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@ -438,9 +435,8 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, absl::Span<const NativeT> expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@ -464,9 +459,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@ -490,9 +485,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@ -516,9 +511,9 @@ template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
@ -542,13 +537,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@ -556,13 +551,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@ -570,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@ -584,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(*literal);
Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal.shape().element_type() == F32) {
literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal->shape(), name);
client_->TransferToServer(literal).ConsumeValueOrDie();
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}

View File

@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
auto computed, client_->Transfer(*data, &expected_literal.shape()));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
expected_literal.shape(), computed.shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
LiteralSlice(*result, {0}));
LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
LiteralSlice(*result, {1}));
LiteralSlice(result, {1}));
EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
EXPECT_TRUE(ShapeUtil::Equal(
ShapeUtil::GetTupleElementShape(result->shape(), 0),
ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
ShapeUtil::GetTupleElementShape(result->shape(), 1),
ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
*LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
client_->Transfer(*results[0], &expected_result.shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}
} // namespace

View File

@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase {
absl::Span<GlobalData* const> arguments,
float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
std::unique_ptr<Literal> result =
Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
*LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");

View File

@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
Client* client, const XlaOp& operand, XlaBuilder* builder,
Layout* output_layout = nullptr) {
StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
XlaBuilder* builder,
Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
return literal->Get<Scalar>({});
return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
expected_literal.shape(), computed.shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}

View File

@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal->shape(), "x");
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "y");
auto y_bcast = Broadcast(y, {1, 5, 7});

View File

@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
LiteralUtil::CreateR0<float>(25.0f).get()}),
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
{pred_arg.get()}, error_spec_);
ComputeAndCompareTuple(&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
{pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<bool>(true).get(),
LiteralUtil::CreateR0<float>(12.2f).get(),
LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
{pred_arg.get()}, error_spec_);
ComputeAndCompareTuple(&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<bool>(true),
LiteralUtil::CreateR0<float>(12.2f),
LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
{pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::MakeTuple(
{LiteralUtil::CreateR0<float>(46.6f).get(),
LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
.get(),
LiteralUtil::MakeTuple(
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(46.6f),
LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
LiteralUtil::CreateR0<float>(b).get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
args.push_back(std::move(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
LiteralUtil::CreateR0<int32>(-42).get()})));
args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
args.push_back(
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
LiteralUtil::CreateR0<int32>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
args.push_back(std::move(
*LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
LiteralUtil::CreateR0<int32>(-42).get()})));
args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
args.push_back(
LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
LiteralUtil::CreateR0<int32>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");

View File

@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4D(input_array);
Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *input_literal);
ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
LiteralUtil::CreateR1<float>({2.0, 42})}));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
LiteralSlice(*result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
LiteralSlice(result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
ConstantLiteral(&builder, *LiteralUtil::CreateToken());
ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());

View File

@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, U32);
@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(

View File

@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());

View File

@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
client_->TransferToServer(filter_r5).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r5,
ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto input_literal =
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
client_->TransferToServer(filter_r4).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r4,
ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto input_literal =
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
client_->TransferToServer(filter_r4).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r4,
ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto input_literal =
client_->TransferToServer(*input_r4).ConsumeValueOrDie();
client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
client_->TransferToServer(filter_r4).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r4,
ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(param0)),
std::move(*LiteralUtil::CreateFromArray(param1))},
{LiteralUtil::CreateFromArray(param0),
LiteralUtil::CreateFromArray(param1)},
error_spec_);
}
@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
auto expected_r3 =
expected_r1->Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
auto input_literal =
client_->TransferToServer(*input_r3).ConsumeValueOrDie();
client_->TransferToServer(input_r3).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r3,
client_->TransferToServer(filter_r3).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))},
{LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
@ -891,9 +890,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillIota(10);
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data)),
std::move(*LiteralUtil::CreateFromArray(filter_data))});
ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
LiteralUtil::CreateFromArray(filter_data)});
}
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
@ -928,8 +926,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
/*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
/*feature_group_count=*/64);
ComputeAndCompare(&builder,
{std::move(*LiteralUtil::CreateFromArray(input_data))},
ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
error_spec_);
}

View File

@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, gradients_literal);
auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto weights = ConstantLiteral(&builder, *weights_literal);
weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto weights = ConstantLiteral(&builder, weights_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
auto activations = ConstantLiteral(&builder, *activations_literal);
activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
auto activations = ConstantLiteral(&builder, activations_literal);
auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, gradients_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto forward_conv =
ConvGeneralDilated(activations, gradients,
@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
} // namespace

View File

@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(literal.CloneToUnique()));
auto constant =
builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
Literal result = ExecuteAndTransfer(std::move(module), {});
EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
TestCopyOp(
*LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
TestCopyOp(*LiteralUtil::CreateR4(
TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
Shape shape = literal->shape();
Shape shape = literal.shape();
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(std::move(module), {literal.get()});
LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {&literal});
LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
Layout* literal_layout =
literal->mutable_shape_do_not_use()->mutable_layout();
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}
@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
Literal literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}
void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
Literal literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}
} // namespace

View File

@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(
*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}
// On the GPU backend, constants get special handling. Someone might pass a
@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get()}));
EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
ExecuteAndTransfer(std::move(module), {&literal0}));
}
} // namespace

View File

@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
module->AddEntryComputation(builder.Build());
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
module->AddEntryComputation(builder.Build());
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
module->AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};

View File

@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
// Try copying the elements back and comparing it
auto handles = result_status.ConsumeValueOrDie();
std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();
std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
handles1[0].reset();
handles1[1].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();
std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();
std::unique_ptr<Literal> literal;
Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
/// Try deallocating one of the repeated elements, then copy
handles[0].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
LiteralUtil::CreateR1<float>({3.14f, -100.25f});
Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});

View File

@ -68,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
*LiteralUtil::MakeTuple(
{LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
*LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
auto lhs_handle =
this->client_
->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
std::unique_ptr<Literal> dot_lhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(
*dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
param.dot_lhs_row_major)));
Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*dot_lhs_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
std::unique_ptr<Literal> dot_rhs_lit =
Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> addend_data;
std::unique_ptr<Literal> addend_lit;
Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;
if (param.has_addend) {
@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}
XlaBuilder builder(TestName());
@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
auto x_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto x_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto y_data =
this->client_
->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
*LiteralUtil::CreateR2FromArray2DWithLayout<T>(
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
*LiteralUtil::CreateR2FromArray2DWithLayout<T>(
LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
*LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(

View File

@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie();
Literal expected_values =
std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
std::move(*LiteralUtil::CreateR0(input_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(input_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
std::move(*LiteralUtil::CreateR0(update_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(update_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
std::move(*LiteralUtil::CreateR0(expected_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR0(expected_value_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
absl::Span<const int> expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR1(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR1(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
.Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
std::unique_ptr<Literal> literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal->ToString();
Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal.ToString();
}
};
@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
auto input = ConstantLiteral(&builder, *input_literal);
auto input = ConstantLiteral(&builder, input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
stream.get(), *start_indices_literal, buffer));
stream.get(), start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client

View File

@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
*LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));

View File

@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal =
Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&

Some files were not shown because too many files have changed in this diff Show More